diff --git a/SpecForge-ext/cache/compiled_kernels/2a/c2aenxafaj3vioqyzq7mx27etpwqzasypu2acikotkgg3rec7mlw.py b/SpecForge-ext/cache/compiled_kernels/2a/c2aenxafaj3vioqyzq7mx27etpwqzasypu2acikotkgg3rec7mlw.py new file mode 100644 index 0000000000000000000000000000000000000000..29106ce49f53859749826e711cd0d561ddf4cf87 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/2a/c2aenxafaj3vioqyzq7mx27etpwqzasypu2acikotkgg3rec7mlw.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/2c/c2cmsqbgkrofzfikzrnehvhp4wxhze4bly4ct5edlg3syiny626e.py b/SpecForge-ext/cache/compiled_kernels/2c/c2cmsqbgkrofzfikzrnehvhp4wxhze4bly4ct5edlg3syiny626e.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6d438aec451660bdd65d83e22f1636c8f6e9ab --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/2c/c2cmsqbgkrofzfikzrnehvhp4wxhze4bly4ct5edlg3syiny626e.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/2k/c2kz55grrshpc3qkvg6jesbu63ts5wlhkwtjukm7zkcvhkilgn76.py b/SpecForge-ext/cache/compiled_kernels/2k/c2kz55grrshpc3qkvg6jesbu63ts5wlhkwtjukm7zkcvhkilgn76.py new file mode 100644 index 0000000000000000000000000000000000000000..1dbcf7f4c1e3e48b39ecb23a0e78d13d90157d7d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/2k/c2kz55grrshpc3qkvg6jesbu63ts5wlhkwtjukm7zkcvhkilgn76.py @@ -0,0 +1,52 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = xindex // ks0 + x0 = (xindex % ks0) + _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32) + tmp1 = ks1*ks2 + tmp2 = tmp0 < tmp1 + tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp5 * tmp6 + tmp8 = tmp7.to(tl.float32) + tmp9 = tmp3 * tmp8 + tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype) + tmp11 = tl.where(tmp2, tmp9, tmp10) + tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK]) + tmp14 = _tmp13 + tmp12 + _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13) + tmp13 = tl.sum(_tmp13, 1)[:, None] + tl.store(out_ptr0 + (x3), tmp13, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/2v/c2vabblrjzyryauc2jram5kwgwvjexq53bdwxugagjegc2xvufuy.py b/SpecForge-ext/cache/compiled_kernels/2v/c2vabblrjzyryauc2jram5kwgwvjexq53bdwxugagjegc2xvufuy.py new file mode 100644 index 0000000000000000000000000000000000000000..f065bee52b60b37bb1b4a9128080d9f8f32e511d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/2v/c2vabblrjzyryauc2jram5kwgwvjexq53bdwxugagjegc2xvufuy.py @@ -0,0 +1,44 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None) diff --git a/SpecForge-ext/cache/compiled_kernels/2v/c2vbm66z3map72ysgiduadjtps3nnrhjldngw5bzue3cm5xo44w5.py b/SpecForge-ext/cache/compiled_kernels/2v/c2vbm66z3map72ysgiduadjtps3nnrhjldngw5bzue3cm5xo44w5.py new file mode 100644 index 0000000000000000000000000000000000000000..4fff621c7f17dba1208978086628fb46b9e4a93b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/2v/c2vbm66z3map72ysgiduadjtps3nnrhjldngw5bzue3cm5xo44w5.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/35/c35huqp6ngzh67kt32kuxoqpghc32fstv4zogcouzabdxxwta3sl.py b/SpecForge-ext/cache/compiled_kernels/35/c35huqp6ngzh67kt32kuxoqpghc32fstv4zogcouzabdxxwta3sl.py new file mode 100644 index 0000000000000000000000000000000000000000..7874f9914c2ca3a1285d08922902064b6d9abe04 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/35/c35huqp6ngzh67kt32kuxoqpghc32fstv4zogcouzabdxxwta3sl.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/35/db497030eed19cbbd19ee623329ec09ad9ad496274b305a5f803696a7ce87fc1.best_config b/SpecForge-ext/cache/compiled_kernels/35/db497030eed19cbbd19ee623329ec09ad9ad496274b305a5f803696a7ce87fc1.best_config new file mode 100644 index 0000000000000000000000000000000000000000..0102fea510b9bf77ab661e714dfc816c066dc0d8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/35/db497030eed19cbbd19ee623329ec09ad9ad496274b305a5f803696a7ce87fc1.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "IK5RT3JGLTF5PMMUH32NIWB2GXNU6R6CGIZSCRHU3I65YM226KDA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/45/c45n6n4rv3f3q66fpyq53ugyel2jmywhufx7ogpqvuyls4hiicz2.py b/SpecForge-ext/cache/compiled_kernels/45/c45n6n4rv3f3q66fpyq53ugyel2jmywhufx7ogpqvuyls4hiicz2.py new file mode 100644 index 0000000000000000000000000000000000000000..811ce6e2c3a9ffdf0bc1189b1b7eace03373a4be --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/45/c45n6n4rv3f3q66fpyq53ugyel2jmywhufx7ogpqvuyls4hiicz2.py @@ -0,0 +1,1065 @@ +# AOT ID: ['9_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/c4/cc4r2l3x4dfli5iih5dji2abfxoclfozqdaqfbdxtcf6lqfpqwdo.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[8, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/kx/ckxiuwld5taodt6aogxkojllbqa6rvgdkesruwe5ssurjxs2lpmw.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=getitem_5] +# %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[8, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:3" = PlaceHolder[target=primals_7] +# %primals_15 : Tensor "i32[8, 1, s56][s56, s56, 1]cuda:3" = PlaceHolder[target=primals_15] +# %primals_17 : Tensor "i32[8, 1, s84, 16][16*s84, 16*s84, 16, 1]cuda:3" = PlaceHolder[target=primals_17] +# %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[8, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:3" = PlaceHolder[target=primals_13] +# %primals_19 : Tensor "i32[8, 1, s99][s99, s99, 1]cuda:3" = PlaceHolder[target=primals_19] +# %primals_21 : Tensor "i32[8, 1, s6, 16][16*s6, 16*s6, 16, 1]cuda:3" = PlaceHolder[target=primals_21] +# %primals_10 : Tensor "i64[8][1]cuda:3" = PlaceHolder[target=primals_10] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1 = args + args.clear() + s0 = primals_8 + s72 = primals_6 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_5, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (8, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, ), (1, )) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (8, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (8, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (8, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (8, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (8, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + assert_size_stride(getitem, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (8, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (8, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream3 = get_raw_stream(3) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 524288, 128, stream=stream3) + del getitem + buf3 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream3 = get_raw_stream(3) + triton_tem_fused_zeros_1.run(primals_1, primals_3, primals_5, getitem_1, buf1, tangents_1, buf3, buf4, primals_9, primals_7, primals_15, primals_17, primals_11, primals_13, primals_19, primals_21, primals_10, buf5, s0, s72, s56, s84, 64 + ((127 + s0) // 128), 8, 8, stream=stream3) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_13 + del primals_15 + del primals_17 + del primals_19 + del primals_21 + del primals_3 + del primals_5 + del primals_7 + del primals_9 + del tangents_1 + return (buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 4096 + primals_6 = 32 + primals_12 = 32 + primals_14 = 32 + primals_16 = 32 + primals_18 = 32 + primals_20 = 32 + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_5 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_7 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:3', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_10 = rand_strided((8, ), (1, ), device='cuda:3', dtype=torch.int64) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_13 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:3', dtype=torch.int32) + primals_15 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:3', dtype=torch.int32) + primals_17 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:3', dtype=torch.int32) + primals_19 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:3', dtype=torch.int32) + primals_21 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:3', dtype=torch.int32) + getitem = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + getitem_1 = rand_strided((8, 32, 2048), (65536, 2048, 1), device='cuda:3', dtype=torch.float32) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16) + fn = lambda: call([primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/45/c45wcuv4sn2ie6lowss5cksehnjgehlinebmvyopum4so5p257dk.py b/SpecForge-ext/cache/compiled_kernels/45/c45wcuv4sn2ie6lowss5cksehnjgehlinebmvyopum4so5p257dk.py new file mode 100644 index 0000000000000000000000000000000000000000..9db303e0de9fc070a47aa240907e2e20b126da5e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/45/c45wcuv4sn2ie6lowss5cksehnjgehlinebmvyopum4so5p257dk.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/4h/9e00c3cbd4f3ffea506c2d972effa4f5d1a03b1819fbd2068ce6d04ad21a37d7.best_config b/SpecForge-ext/cache/compiled_kernels/4h/9e00c3cbd4f3ffea506c2d972effa4f5d1a03b1819fbd2068ce6d04ad21a37d7.best_config new file mode 100644 index 0000000000000000000000000000000000000000..a570e8d663ff6e600f50df05a811c859065ec3c4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4h/9e00c3cbd4f3ffea506c2d972effa4f5d1a03b1819fbd2068ce6d04ad21a37d7.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 21, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/4h/c4hrpftpfto2n4yelfxmq5tawsfst2z5xq7othxvdoymqaudsvcw.py b/SpecForge-ext/cache/compiled_kernels/4h/c4hrpftpfto2n4yelfxmq5tawsfst2z5xq7othxvdoymqaudsvcw.py new file mode 100644 index 0000000000000000000000000000000000000000..68c3d29233ea69dd6f320cbd46b2ca9bb269397a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4h/c4hrpftpfto2n4yelfxmq5tawsfst2z5xq7othxvdoymqaudsvcw.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/4k/30a0e09dbdf44769796e9e261da2a9dcbfc798ae7811e19f9adc033f960f3fae.best_config b/SpecForge-ext/cache/compiled_kernels/4k/30a0e09dbdf44769796e9e261da2a9dcbfc798ae7811e19f9adc033f960f3fae.best_config new file mode 100644 index 0000000000000000000000000000000000000000..73d39cec03a4913ffd38deb7ad038bf56b5cd33f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4k/30a0e09dbdf44769796e9e261da2a9dcbfc798ae7811e19f9adc033f960f3fae.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/4k/c4kzcehfveyvvtlnmx5jh5naezqnmtz2ubxuawsucb27r43j5yfa.py b/SpecForge-ext/cache/compiled_kernels/4k/c4kzcehfveyvvtlnmx5jh5naezqnmtz2ubxuawsucb27r43j5yfa.py new file mode 100644 index 0000000000000000000000000000000000000000..b82069d3f50671b6eef10f355f30201954266f64 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4k/c4kzcehfveyvvtlnmx5jh5naezqnmtz2ubxuawsucb27r43j5yfa.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 256, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/4m/c4mv34wib446qhr7sd5yhgc4mdneb7isnb6uitnbwvdgrbpgyf2s.py b/SpecForge-ext/cache/compiled_kernels/4m/c4mv34wib446qhr7sd5yhgc4mdneb7isnb6uitnbwvdgrbpgyf2s.py new file mode 100644 index 0000000000000000000000000000000000000000..f67648eb8347be4166e5541083f2fa19de5f9453 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4m/c4mv34wib446qhr7sd5yhgc4mdneb7isnb6uitnbwvdgrbpgyf2s.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/4u/c4uf4o6eypfpqr4isgii4opqr5i3brobwecljte7sqvztk2kyafz.py b/SpecForge-ext/cache/compiled_kernels/4u/c4uf4o6eypfpqr4isgii4opqr5i3brobwecljte7sqvztk2kyafz.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6909e55d8bc822ed80ce8b2719d42655dd65da --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4u/c4uf4o6eypfpqr4isgii4opqr5i3brobwecljte7sqvztk2kyafz.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/4u/c4uhrh7gjsy72in52pmmkpoiwetwjbked3nkrbcotbo4sj5bq7bi.py b/SpecForge-ext/cache/compiled_kernels/4u/c4uhrh7gjsy72in52pmmkpoiwetwjbked3nkrbcotbo4sj5bq7bi.py new file mode 100644 index 0000000000000000000000000000000000000000..e20f6ceaf81c1a35b54840fb40764d79a923ebfc --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4u/c4uhrh7gjsy72in52pmmkpoiwetwjbked3nkrbcotbo4sj5bq7bi.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/4x/c4xykt7eysbenti5r55drq4w7k6c7fih4ifrou2alyqcn6r5enon.py b/SpecForge-ext/cache/compiled_kernels/4x/c4xykt7eysbenti5r55drq4w7k6c7fih4ifrou2alyqcn6r5enon.py new file mode 100644 index 0000000000000000000000000000000000000000..b0463e5141eb210d874128e3eb9119e5409c8ff9 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4x/c4xykt7eysbenti5r55drq4w7k6c7fih4ifrou2alyqcn6r5enon.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/54/c5464ptly4n22voq77yo3wrltmxhbase2ojnypkgcpcxg6js4oty.py b/SpecForge-ext/cache/compiled_kernels/54/c5464ptly4n22voq77yo3wrltmxhbase2ojnypkgcpcxg6js4oty.py new file mode 100644 index 0000000000000000000000000000000000000000..e2dff8879e79e11f724ac14647d83c4e5b3b8d0b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/54/c5464ptly4n22voq77yo3wrltmxhbase2ojnypkgcpcxg6js4oty.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 262144000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/54/c54p5bozrk7z3jkhpl6meytxfu7bz7ojmkijrdgczbq55oalwpgl.py b/SpecForge-ext/cache/compiled_kernels/54/c54p5bozrk7z3jkhpl6meytxfu7bz7ojmkijrdgczbq55oalwpgl.py new file mode 100644 index 0000000000000000000000000000000000000000..31ab417792e26b0190e2850b104b0e9edb5e3cb0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/54/c54p5bozrk7z3jkhpl6meytxfu7bz7ojmkijrdgczbq55oalwpgl.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/5h/c5h6tol66uk77tfumu3xd25ecbr6kkxkqgk3zbmjpk4tc6sikmjb.py b/SpecForge-ext/cache/compiled_kernels/5h/c5h6tol66uk77tfumu3xd25ecbr6kkxkqgk3zbmjpk4tc6sikmjb.py new file mode 100644 index 0000000000000000000000000000000000000000..be064806fe03cc14ee6de4c783901101c59bc96c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5h/c5h6tol66uk77tfumu3xd25ecbr6kkxkqgk3zbmjpk4tc6sikmjb.py @@ -0,0 +1,543 @@ +# AOT ID: ['5_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/jh/cjhd7kndnunfa7ikwg3gxzzxuods7fnn5vlwqbhjxnla3dldi6sq.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge, view +# diagnol_mask => eq +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# m => iota_2 +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub, view_7 +# suffix_mask => ge_1 +# Graph fragment: +# %arg0_1 : Tensor "i64[8][1]cuda:1" = PlaceHolder[target=arg0_1] +# %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[2048][1]cuda:1"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %ge : Tensor "b8[2048, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[8][1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %index : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {}) +# %lt : Tensor "b8[8, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, 2048]), kwargs = {}) +# %index_1 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_1 : Tensor "b8[2048][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[2048][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[8, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[8, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, 2048]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub : Tensor "i64[2048, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub, 2048), kwargs = {}) +# %eq : Tensor "b8[2048, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[8, 1, 2048, 2048][4194304, 4194304, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, 2048]), kwargs = {}) +# %view_10 : Tensor "b8[8, 1, 16, 128, 16, 128][4194304, 4194304, 262144, 2048, 128, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [8, 1, 16, 128, 16, 128]), kwargs = {}) +# %permute : Tensor "b8[8, 1, 16, 16, 128, 128][4194304, 4194304, 262144, 128, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# return %sum_1 +triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2048 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/6f/c6fuhct5vdp3d5lx45chz27ghag5dfreh2h3hbzxl5elhim3qhpx.py +# Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_4 => full_default_4 +# Graph fragment: +# %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# return %index_put_1 +triton_poi_fused_new_zeros_1 = async_compile.triton('triton_poi_fused_new_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 2176 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/c2/cc2qlkbbemfommyywsdbow3sqg7jqf5x5tfkbqjzo2qy6lt36yjr.py +# Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# arange_6 => iota_8 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# child_7 => convert_element_type_6 +# child_8 => convert_element_type_7 +# col_indices => sort +# col_indices_1 => sort_1 +# col_range => iota_5 +# col_range_1 => iota_9 +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# dense_mask_2 => full_default_1 +# dense_mask_4 => full_default_4 +# full_blocks => eq_1 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index_mask => lt_4 +# index_mask_1 => lt_5 +# lt_3 => lt_3 +# num_blocks_in_row => sum_2 +# num_blocks_in_row_1 => sum_3 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# row_indices => unsqueeze +# row_indices_1 => unsqueeze_7 +# setitem => full_default_3, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# setitem_1 => full_default_6, index_put_1, iota_10, iota_11, unsqueeze_10, unsqueeze_11, unsqueeze_12, unsqueeze_13, unsqueeze_9 +# unsqueeze_1 => unsqueeze_1 +# unsqueeze_3 => unsqueeze_8 +# valid_indices => full_default_2, where +# valid_indices_1 => full_default_5, where_1 +# Graph fragment: +# %sum_1 : Tensor "i64[8, 1, 16, 16][256, 2048, 16, 1]cuda:1" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:1" = PlaceHolder[target=sum_2] +# %sum_3 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:1" = PlaceHolder[target=sum_3] +# %buf2 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:1" = PlaceHolder[target=buf2] +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1" = PlaceHolder[target=index_put] +# %buf4 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:1" = PlaceHolder[target=buf4] +# %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=convert_element_type_6] +# %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1" = PlaceHolder[target=convert_element_type_7] +# %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1" = PlaceHolder[target=index_put_1] +# %gt : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_2,), kwargs = {stable: True, descending: True}) +# %eq_1 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_1, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_5,), kwargs = {stable: True, descending: True}) +# %full_default_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %iota_7 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:1, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[16][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:1, requires_grad: False}) +# %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %full_default_2 : Tensor "i32[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %where : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %full_default_2), kwargs = {}) +# %full_default_3 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_3), kwargs = {}) +# %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %iota_11 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %unsqueeze_11 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_11, -1), kwargs = {}) +# %unsqueeze_12 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_11, -1), kwargs = {}) +# %unsqueeze_13 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_12, -1), kwargs = {}) +# %iota_10 : Tensor "i64[1][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %unsqueeze_9 : Tensor "i64[1, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_10, -1), kwargs = {}) +# %unsqueeze_10 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_9, -1), kwargs = {}) +# %iota_8 : Tensor "i32[16][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:1, requires_grad: False}) +# %unsqueeze_7 : Tensor "i32[16, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_8, -1), kwargs = {}) +# %iota_9 : Tensor "i32[16][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:1, requires_grad: False}) +# %sum_3 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_5, [-1]), kwargs = {}) +# %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, torch.int32), kwargs = {}) +# %unsqueeze_8 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_6, 3), kwargs = {}) +# %lt_5 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_9, %unsqueeze_8), kwargs = {}) +# %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_3, torch.int32), kwargs = {}) +# %full_default_5 : Tensor "i32[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %where_1 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_5, %convert_element_type_7, %full_default_5), kwargs = {}) +# %full_default_6 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_4, [%unsqueeze_13, %unsqueeze_10, %unsqueeze_7, %where_1], %full_default_6), kwargs = {}) +# return %buf2,%buf4,%sum_2,%sum_3,%convert_element_type_3,%convert_element_type_6,%convert_element_type_4,%buf9,%convert_element_type_7,%buf16 +triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 = async_compile.triton('triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/nj/cnjktwj7h4iwx4zghbum5atne46yt4ce4t5jnkkvyag35pn7glnh.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf9 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1" = PlaceHolder[target=buf9] +# %buf11 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:1" = PlaceHolder[target=buf11] +# %sum_4 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:1" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[8, 1, 16, 16][272, 272, 17, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, 16), kwargs = {}) +# %clone_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf11,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (8, ), (1, )) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((8, 1, 16, 16), (256, 2048, 16, 1), torch.int64) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] + stream1 = get_raw_stream(1) + triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0.run(arg0_1, buf0, 2048, 16384, stream=stream1) + del arg0_1 + buf15 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + stream1 = get_raw_stream(1) + triton_poi_fused_new_zeros_1.run(buf15, 2176, stream=stream1) + buf8 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + stream1 = get_raw_stream(1) + triton_poi_fused_new_zeros_1.run(buf8, 2176, stream=stream1) + buf6 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf13 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf7 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf14 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + stream1 = get_raw_stream(1) + triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.run(buf0, buf6, buf13, buf7, buf8, buf14, buf15, 128, 16, stream=stream1) + del buf0 + buf22 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf24 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream1 = get_raw_stream(1) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf8, buf22, buf24, 128, 16, stream=stream1) + del buf8 + buf19 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf21 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream1 = get_raw_stream(1) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf15, buf19, buf21, 128, 16, stream=stream1) + del buf15 + return (buf19, buf21, buf22, buf24, buf14, buf13, buf7, buf6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, ), (1, ), device='cuda:1', dtype=torch.int64) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/5p/c5pbkg5eq64emuv25ukki7a5dxvn2p2sh6jeiwb6b54tbidps5w7.py b/SpecForge-ext/cache/compiled_kernels/5p/c5pbkg5eq64emuv25ukki7a5dxvn2p2sh6jeiwb6b54tbidps5w7.py new file mode 100644 index 0000000000000000000000000000000000000000..3b40b261ed3f30a7c29640988ce81d597237c53b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5p/c5pbkg5eq64emuv25ukki7a5dxvn2p2sh6jeiwb6b54tbidps5w7.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/5s/c5siycmobmba5rqczjfbtd45di6el6qnpizugzs3hsg4jzkcqnpk.py b/SpecForge-ext/cache/compiled_kernels/5s/c5siycmobmba5rqczjfbtd45di6el6qnpizugzs3hsg4jzkcqnpk.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9f5739d4887de4cdf62ab33f9110691b43ea75 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5s/c5siycmobmba5rqczjfbtd45di6el6qnpizugzs3hsg4jzkcqnpk.py @@ -0,0 +1,161 @@ +# AOT ID: ['11_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/lw/clwnecq6ifpvev5aiszbhu6i732z6eomppbbe2l6ohgsvjmgczzn.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s67, 32000][32000*s67, 32000, 1]cuda:4" = PlaceHolder[target=arg1_1] +# %getitem : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:4" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:4" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg1_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1 = args + args.clear() + s67 = arg0_1 + assert_size_stride(arg1_1, (2, s67, 32000), (32000*s67, 32000, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf2 = empty_strided_cuda((2, s67, 32000), (32000*s67, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel = 2*s67 + stream4 = get_raw_stream(4) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg1_1, buf2, triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel, 32000, stream=stream4) + del arg1_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1543 + arg1_1 = rand_strided((2, 1543, 32000), (49376000, 32000, 1), device='cuda:4', dtype=torch.bfloat16) + fn = lambda: call([arg0_1, arg1_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/5u/235c5fbee66a14cc3d65896905ec816ec90c51ba6594c4a627960306977eb07c.best_config b/SpecForge-ext/cache/compiled_kernels/5u/235c5fbee66a14cc3d65896905ec816ec90c51ba6594c4a627960306977eb07c.best_config new file mode 100644 index 0000000000000000000000000000000000000000..26f34a32396bf93c323cd255a2cf49b0585d7f4b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5u/235c5fbee66a14cc3d65896905ec816ec90c51ba6594c4a627960306977eb07c.best_config @@ -0,0 +1 @@ +{"XBLOCK": 32, "R0_BLOCK": 16, "num_warps": 4, "num_stages": 1, "configs_hash": "21ad1ee516cd6d15e1fb8e88c10082cd54bef654f8a281c7d5ccd54b6509a685", "found_by_coordesc": false, "time_taken_ms": 28, "triton_cache_hash": "2HBOMUT44J5WFCUWYGRFAAS3HGVNDHLHT7HCSXUCAOIKU6XGJNTA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/6b/c6beknosybos5d54llineldguuueh3kpjlkiuzm4pkorx7g6mjh6.py b/SpecForge-ext/cache/compiled_kernels/6b/c6beknosybos5d54llineldguuueh3kpjlkiuzm4pkorx7g6mjh6.py new file mode 100644 index 0000000000000000000000000000000000000000..d6013163c6e3d52b314f9edcb7a722beaf9d485e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6b/c6beknosybos5d54llineldguuueh3kpjlkiuzm4pkorx7g6mjh6.py @@ -0,0 +1,45 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/6b/c6bpf3ctcqs5wvcac26go3fcp5hdc2pxduwgba2cnxt52xqmp6mq.py b/SpecForge-ext/cache/compiled_kernels/6b/c6bpf3ctcqs5wvcac26go3fcp5hdc2pxduwgba2cnxt52xqmp6mq.py new file mode 100644 index 0000000000000000000000000000000000000000..f20db2602979a76af8a03d42108ae4fe34355a9f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6b/c6bpf3ctcqs5wvcac26go3fcp5hdc2pxduwgba2cnxt52xqmp6mq.py @@ -0,0 +1,334 @@ +# AOT ID: ['2_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zg/czg53pk3l24wn74a6bylpzbgb44kx2zfplies7n5uiiogfzwg4z2.py +# Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# hidden_states_1 => mul_16 +# to_1 => convert_element_type_1 +# Graph fragment: +# %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7" = PlaceHolder[target=primals_4] +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7" = PlaceHolder[target=rsqrt] +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {}) +# %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {}) +# %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {}) +# %sum_1 : Tensor "bf16[1, 1, s33][s33, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {}) +# return %buf0 +triton_red_fused__to_copy_mul_sum_0 = async_compile.triton('triton_red_fused__to_copy_mul_sum_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = xindex // ks0 + x0 = (xindex % ks0) + _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32) + tmp1 = ks1*ks2 + tmp2 = tmp0 < tmp1 + tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp5 * tmp6 + tmp8 = tmp7.to(tl.float32) + tmp9 = tmp3 * tmp8 + tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype) + tmp11 = tl.where(tmp2, tmp9, tmp10) + tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK]) + tmp14 = _tmp13 + tmp12 + _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13) + tmp13 = tl.sum(_tmp13, 1)[:, None] + tl.store(out_ptr0 + (x3), tmp13, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ut/cutp3chhk5c6s5fxb2gqzhrx5hjq4ltt3ybguoemttw3toknshg6.py +# Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# hidden_states_1 => mul_16 +# to_1 => convert_element_type_1 +# Graph fragment: +# %buf0 : Tensor "f32[1, 1, s33, 32][32*s33, 32*s33, 1, s33]cuda:7" = PlaceHolder[target=buf0] +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {}) +# %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {}) +# %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {}) +# %sum_1 : Tensor "bf16[1, 1, s33][s33, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {}) +# return %sum_1 +triton_per_fused__to_copy_mul_sum_1 = async_compile.triton('triton_per_fused__to_copy_mul_sum_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, 0) + tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32) + tl.store(out_ptr0 + (x0), tmp4, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/im/cimq7s4zgz63carjnhuvinchsq4odrr475l6qsymkihvbxvheq7a.py +# Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# Graph fragment: +# %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %primals_7 : Tensor "bf16[s33][1]cuda:7" = PlaceHolder[target=primals_7] +# %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7" = PlaceHolder[target=primals_4] +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7" = PlaceHolder[target=rsqrt] +# %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, s47*s87]cuda:7" = PlaceHolder[target=sum_2] +# %mul_27 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %primals_7), kwargs = {}) +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %convert_element_type_2 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_27, torch.float32), kwargs = {}) +# %mul_29 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %convert_element_type), kwargs = {}) +# %mul_30 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %rsqrt), kwargs = {}) +# %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_29, [2], True), kwargs = {}) +# %pow_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%rsqrt, 3), kwargs = {}) +# %mul_31 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%sum_2, -0.5), kwargs = {}) +# %mul_32 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_31, %pow_2), kwargs = {}) +# %expand : Tensor "f32[s47, s87, s33][s87, 1, 0]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%mul_32, [%primals_1, %primals_2, %primals_3]), kwargs = {}) +# %div : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand, %primals_3), kwargs = {}) +# %pow_3 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 1.0), kwargs = {}) +# %mul_33 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_3, 2.0), kwargs = {}) +# %mul_34 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div, %mul_33), kwargs = {}) +# %add_37 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_30, %mul_34), kwargs = {}) +# %convert_element_type_3 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_37, torch.bfloat16), kwargs = {}) +# return %sum_2,%convert_element_type_3 +triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2 = async_compile.triton('triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tmp2.to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tmp3 * tmp5 + tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK]) + tmp9 = _tmp8 + tmp7 + _tmp8 = tl.where(r0_mask & xmask, tmp9, _tmp8) + tmp8 = tl.sum(_tmp8, 1)[:, None] + tmp14 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last') + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp10 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp11 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp24 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp12 = tmp10 * tmp11 + tmp13 = tmp12.to(tl.float32) + tmp15 = tmp13 * tmp14 + tmp16 = -0.5 + tmp17 = tmp8 * tmp16 + tmp18 = tmp14 * tmp14 + tmp19 = tmp18 * tmp14 + tmp20 = tmp17 * tmp19 + tmp21 = ks0 + tmp22 = tmp21.to(tl.float32) + tmp23 = (tmp20 / tmp22) + tmp25 = tmp24.to(tl.float32) + tmp26 = 2.0 + tmp27 = tmp25 * tmp26 + tmp28 = tmp23 * tmp27 + tmp29 = tmp15 + tmp28 + tmp30 = tmp29.to(tl.float32) + tl.store(out_ptr1 + (r0_1 + ks0*x0), tmp30, r0_mask & xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1 = args + args.clear() + s47 = primals_1 + s87 = primals_2 + s33 = primals_3 + s82 = primals_6 + assert_size_stride(primals_4, (s47, s87, s33), (s33*s87, s33, 1)) + assert_size_stride(primals_7, (s33, ), (1, )) + assert_size_stride(rsqrt, (s47, s87, 1), (s87, 1, 1)) + assert_size_stride(tangents_1, (s47, s87, s33), (s33*s87, s33, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((1, 1, s33, 32), (32*s33, 32*s33, 1, s33), torch.float32) + # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] + triton_red_fused__to_copy_mul_sum_0_xnumel = 32*s33 + triton_red_fused__to_copy_mul_sum_0_r0_numel = (31 + s47*s87) // 32 + stream7 = get_raw_stream(7) + triton_red_fused__to_copy_mul_sum_0.run(tangents_1, primals_4, rsqrt, buf0, s33, s47, s87, triton_red_fused__to_copy_mul_sum_0_xnumel, triton_red_fused__to_copy_mul_sum_0_r0_numel, stream=stream7) + buf1 = empty_strided_cuda((1, 1, s33), (s33, s33, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] + stream7 = get_raw_stream(7) + triton_per_fused__to_copy_mul_sum_1.run(buf0, buf1, s33, s33, 32, stream=stream7) + del buf0 + buf3 = empty_strided_cuda((s47, s87, s33), (s33*s87, s33, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add] + triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel = s47*s87 + stream7 = get_raw_stream(7) + triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2.run(tangents_1, primals_7, primals_4, rsqrt, buf3, s33, triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel, s33, stream=stream7) + del primals_4 + del primals_7 + del rsqrt + del tangents_1 + return (None, None, None, buf3, None, None, reinterpret_tensor(buf1, (s33, ), (1, ), 0), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2 + primals_2 = 2048 + primals_3 = 4096 + primals_6 = 840433664 + primals_4 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_7 = rand_strided((4096, ), (1, ), device='cuda:7', dtype=torch.bfloat16) + rsqrt = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:7', dtype=torch.float32) + tangents_1 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/6j/b801eb968d13baeef00c09ffebb7c203c75661545f70c7ec4ed906e946ad8a67.best_config b/SpecForge-ext/cache/compiled_kernels/6j/b801eb968d13baeef00c09ffebb7c203c75661545f70c7ec4ed906e946ad8a67.best_config new file mode 100644 index 0000000000000000000000000000000000000000..2a95815b49cfc301dd2a3d06bb1b105b04bfbae7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6j/b801eb968d13baeef00c09ffebb7c203c75661545f70c7ec4ed906e946ad8a67.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "XAIV2GWX5UZL7NNOCKNWC2I6AATKI6664P6FTQPRXS2M4AR4WJWA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/6j/c6jx5fvfijye7zqqg42xonpcdfuwatv7bizrwompd5o3dua57uju.py b/SpecForge-ext/cache/compiled_kernels/6j/c6jx5fvfijye7zqqg42xonpcdfuwatv7bizrwompd5o3dua57uju.py new file mode 100644 index 0000000000000000000000000000000000000000..5839e0a1e772d18eae23a4aa17d00333f20ee314 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6j/c6jx5fvfijye7zqqg42xonpcdfuwatv7bizrwompd5o3dua57uju.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 8192}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/6o/c6obqatzdeyb7elxstetxuvmlhbvwph6buxkixqs4flvdn2x6vgl.py b/SpecForge-ext/cache/compiled_kernels/6o/c6obqatzdeyb7elxstetxuvmlhbvwph6buxkixqs4flvdn2x6vgl.py new file mode 100644 index 0000000000000000000000000000000000000000..c91fa8bbcbcfee85c7f3ff18c31c1a55315fee63 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6o/c6obqatzdeyb7elxstetxuvmlhbvwph6buxkixqs4flvdn2x6vgl.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/7g/59ff39d5526de7bb833fbd386ca3ce564bdaf6828f559a423e599b5ad90d0456.best_config b/SpecForge-ext/cache/compiled_kernels/7g/59ff39d5526de7bb833fbd386ca3ce564bdaf6828f559a423e599b5ad90d0456.best_config new file mode 100644 index 0000000000000000000000000000000000000000..a337a719c6503c8dcbad0c427c4a5067600d0bd0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/7g/59ff39d5526de7bb833fbd386ca3ce564bdaf6828f559a423e599b5ad90d0456.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "6FB7I6IASCIGI3DSKLBL4Q2CXFFWPYWXW7AMHNUUDLPGKUCB3PDA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/7m/c7mmadjna7dltm72lxvsoktdadnw2jtxufsj2eoflefh2r5jo4gq.py b/SpecForge-ext/cache/compiled_kernels/7m/c7mmadjna7dltm72lxvsoktdadnw2jtxufsj2eoflefh2r5jo4gq.py new file mode 100644 index 0000000000000000000000000000000000000000..422b819400be427ee53d9463a44c8214e19a5333 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/7m/c7mmadjna7dltm72lxvsoktdadnw2jtxufsj2eoflefh2r5jo4gq.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 8192}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/7m/e130479b4d145e755b390ab3b709dd817d1548c0596f91391e7581de8609a9eb.best_config b/SpecForge-ext/cache/compiled_kernels/7m/e130479b4d145e755b390ab3b709dd817d1548c0596f91391e7581de8609a9eb.best_config new file mode 100644 index 0000000000000000000000000000000000000000..2a95815b49cfc301dd2a3d06bb1b105b04bfbae7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/7m/e130479b4d145e755b390ab3b709dd817d1548c0596f91391e7581de8609a9eb.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "XAIV2GWX5UZL7NNOCKNWC2I6AATKI6664P6FTQPRXS2M4AR4WJWA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/7o/c7oiol3zozs5oktlpjhg3lu46rhbgu3bqq6yibefmn2imo6bua5k.py b/SpecForge-ext/cache/compiled_kernels/7o/c7oiol3zozs5oktlpjhg3lu46rhbgu3bqq6yibefmn2imo6bua5k.py new file mode 100644 index 0000000000000000000000000000000000000000..dcd4f7b399694be9e005a5e8c7b6b7010d90c3cc --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/7o/c7oiol3zozs5oktlpjhg3lu46rhbgu3bqq6yibefmn2imo6bua5k.py @@ -0,0 +1,48 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 2097152000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/7z/c7z2jbjub3aupgnechol65vkvi5ruwpylzosdbqvscdyxmreb3jy.py b/SpecForge-ext/cache/compiled_kernels/7z/c7z2jbjub3aupgnechol65vkvi5ruwpylzosdbqvscdyxmreb3jy.py new file mode 100644 index 0000000000000000000000000000000000000000..8a202a4e335de762fbf7b0cff532f5aff9ad85b3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/7z/c7z2jbjub3aupgnechol65vkvi5ruwpylzosdbqvscdyxmreb3jy.py @@ -0,0 +1,86 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/7z/c7z6kbhlhnd55iz3suxpzcfjhjv7p7i2zelu2nitjoegrwczbdyf.py b/SpecForge-ext/cache/compiled_kernels/7z/c7z6kbhlhnd55iz3suxpzcfjhjv7p7i2zelu2nitjoegrwczbdyf.py new file mode 100644 index 0000000000000000000000000000000000000000..52e3b721afcd85c57875e994dd3388218480787b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/7z/c7z6kbhlhnd55iz3suxpzcfjhjv7p7i2zelu2nitjoegrwczbdyf.py @@ -0,0 +1,52 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = xindex // ks0 + x0 = (xindex % ks0) + _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32) + tmp1 = ks1*ks2 + tmp2 = tmp0 < tmp1 + tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp5 * tmp6 + tmp8 = tmp7.to(tl.float32) + tmp9 = tmp3 * tmp8 + tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype) + tmp11 = tl.where(tmp2, tmp9, tmp10) + tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK]) + tmp14 = _tmp13 + tmp12 + _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13) + tmp13 = tl.sum(_tmp13, 1)[:, None] + tl.store(out_ptr0 + (x3), tmp13, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/7z/de291d239bdb6c33244f90904700e0423d0a8026bdcf04c4cb1f87b0edee041b.best_config b/SpecForge-ext/cache/compiled_kernels/7z/de291d239bdb6c33244f90904700e0423d0a8026bdcf04c4cb1f87b0edee041b.best_config new file mode 100644 index 0000000000000000000000000000000000000000..72b37247d185ec4d7af927732af35652bde2948b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/7z/de291d239bdb6c33244f90904700e0423d0a8026bdcf04c4cb1f87b0edee041b.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/aa/caa67m6yhgzsw5semsgkn3vvui6pjb2e2mxtfb5xyoo3c5qle6ao.py b/SpecForge-ext/cache/compiled_kernels/aa/caa67m6yhgzsw5semsgkn3vvui6pjb2e2mxtfb5xyoo3c5qle6ao.py new file mode 100644 index 0000000000000000000000000000000000000000..bb9716e1eec6bb7a58930a3beb5aa07dd5cbc567 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/aa/caa67m6yhgzsw5semsgkn3vvui6pjb2e2mxtfb5xyoo3c5qle6ao.py @@ -0,0 +1,320 @@ +# AOT ID: ['4_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/n5/cn5h4iq6wlljobax2ulslga4k6zxontovelmyztexccj4qb2xkei.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_2 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6" = PlaceHolder[target=tangents_2] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:6" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_84 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {}) +# %slice_5 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {}) +# %slice_6 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_2 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {}) +# %full_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_10, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %slice_scatter_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {}) +# %add_100 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_85 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {}) +# %add_101 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {}) +# return %add_101 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/eg/cegphctwzx57aawblx7563zff7jofvfpmllo4f2poi5emt43dc5t.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:6" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {}) +# %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {}) +# %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {}) +# %full_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {}) +# %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {}) +# %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {}) +# %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {}) +# return %add_107 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args + args.clear() + s24 = primals_2 + s9 = primals_7 + s48 = primals_10 + s34 = primals_11 + s92 = primals_1 + s96 = primals_3 + s79 = primals_5 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1)) + assert_size_stride(tangents_2, (s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s9*s48*s48 + stream6 = get_raw_stream(6) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream6) + del tangents_2 + buf1 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9 + stream6 = get_raw_stream(6) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream6) + del primals_4 + del primals_6 + del primals_8 + del tangents_1 + return (None, None, None, None, None, None, None, None, None, None, None, buf1, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_2 = 128 + primals_7 = 2048 + primals_10 = 8 + primals_11 = 32 + primals_1 = 2048 + primals_3 = 5245440 + primals_5 = 2048 + floordiv = 64 + add_96 = 64 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:6', dtype=torch.int64) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:6', dtype=torch.bfloat16) + tangents_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:6', dtype=torch.bfloat16) + fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/aa/caabkjzbaqm7hrv3ypoalyjx45pdt7jezorxxk75d4cahg2knncu.py b/SpecForge-ext/cache/compiled_kernels/aa/caabkjzbaqm7hrv3ypoalyjx45pdt7jezorxxk75d4cahg2knncu.py new file mode 100644 index 0000000000000000000000000000000000000000..3ded5d78476a63bb9b5f233c6201d6eda155b5ec --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/aa/caabkjzbaqm7hrv3ypoalyjx45pdt7jezorxxk75d4cahg2knncu.py @@ -0,0 +1,89 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1024, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/af/cafe3dsuelcloemwu5jdikp7lqano5qxv7iayhtm5xgji2xvr4k6.py b/SpecForge-ext/cache/compiled_kernels/af/cafe3dsuelcloemwu5jdikp7lqano5qxv7iayhtm5xgji2xvr4k6.py new file mode 100644 index 0000000000000000000000000000000000000000..3905453059559f60329dec1607d9f1e09a9d8d70 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/af/cafe3dsuelcloemwu5jdikp7lqano5qxv7iayhtm5xgji2xvr4k6.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ai/caivmpnbt7ve3qybkm6k756igdxn3ykevul35fdg4vvgknrmprqo.py b/SpecForge-ext/cache/compiled_kernels/ai/caivmpnbt7ve3qybkm6k756igdxn3ykevul35fdg4vvgknrmprqo.py new file mode 100644 index 0000000000000000000000000000000000000000..235c25c9c1c58e320f87d60437b84befadf4a8ef --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ai/caivmpnbt7ve3qybkm6k756igdxn3ykevul35fdg4vvgknrmprqo.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ai/f2f38be4dfdf6b1c14c068f88a04203cd9a67c3fc07629f341d6212e60d2f52e.best_config b/SpecForge-ext/cache/compiled_kernels/ai/f2f38be4dfdf6b1c14c068f88a04203cd9a67c3fc07629f341d6212e60d2f52e.best_config new file mode 100644 index 0000000000000000000000000000000000000000..cbf4eb5ae8826a07243c88f3ee991df371ea45fb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ai/f2f38be4dfdf6b1c14c068f88a04203cd9a67c3fc07629f341d6212e60d2f52e.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 53, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/al/25feb68bb70a2d653884ed092be99a324d74e7c4fa2b0800c70b0c5cede23a82.best_config b/SpecForge-ext/cache/compiled_kernels/al/25feb68bb70a2d653884ed092be99a324d74e7c4fa2b0800c70b0c5cede23a82.best_config new file mode 100644 index 0000000000000000000000000000000000000000..3fc56f57c375dacceeb71ea2e7a129667d8c493f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/al/25feb68bb70a2d653884ed092be99a324d74e7c4fa2b0800c70b0c5cede23a82.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 50, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/al/cal2r4tfyw6gic3ggqyud3nufnajx6xau2koieoitx6zg4wsiozm.py b/SpecForge-ext/cache/compiled_kernels/al/cal2r4tfyw6gic3ggqyud3nufnajx6xau2koieoitx6zg4wsiozm.py new file mode 100644 index 0000000000000000000000000000000000000000..66c9e6f24b76866bc1fab51061fcb63cbcb002bb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/al/cal2r4tfyw6gic3ggqyud3nufnajx6xau2koieoitx6zg4wsiozm.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/aq/caqqpjwqelw7hv6k6nwpxjuod3tfnwg62cypxwyuozfme2ykuybp.py b/SpecForge-ext/cache/compiled_kernels/aq/caqqpjwqelw7hv6k6nwpxjuod3tfnwg62cypxwyuozfme2ykuybp.py new file mode 100644 index 0000000000000000000000000000000000000000..fa11f54a0c98dd94ad325776f0d3d28b6925cb20 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/aq/caqqpjwqelw7hv6k6nwpxjuod3tfnwg62cypxwyuozfme2ykuybp.py @@ -0,0 +1,307 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3k/c3kdupo6eufhy2marzoeoddgc3okqj6m3aii3f42onl4ag77vf6u.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat => cat +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# mul => mul_24 +# mul_1 => mul_45 +# neg => neg +# q_embed => add_54 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1 => slice_1 +# x2 => slice_2 +# Graph fragment: +# %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:5" = PlaceHolder[target=primals_12] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:5" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {}) +# %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {}) +# %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {}) +# %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {}) +# %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {}) +# %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {}) +# return %add_54 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/n2/cn24lurjdnbidkarxbtzqpcvotiay3hsbqwsbqw73gg63elg6tak.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat_1 => cat_1 +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# k_embed => add_90 +# mul_2 => mul_54 +# mul_3 => mul_75 +# neg_1 => neg_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1_1 => slice_3 +# x2_1 => slice_4 +# Graph fragment: +# %primals_13 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:5" = PlaceHolder[target=primals_13] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:5" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_54 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_13, %unsqueeze), kwargs = {}) +# %slice_3 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24, s24*s48, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, 0, %floordiv), kwargs = {}) +# %slice_4 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24, s24*s48, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg_1 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s48*Max(1, s24 - ((s24//2))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {}) +# %cat_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {}) +# %mul_75 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {}) +# %add_90 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {}) +# return %add_90 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13 = args + args.clear() + s92 = primals_1 + s24 = primals_2 + s96 = primals_3 + s79 = primals_5 + s9 = primals_7 + s38 = primals_9 + s48 = primals_10 + s34 = primals_11 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1)) + assert_size_stride(primals_13, (s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + ps0 = s24*s34 + buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9 + stream5 = get_raw_stream(5) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream5) + del primals_12 + ps1 = s24*s48 + buf1 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s9*s48*s48 + stream5 = get_raw_stream(5) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_13, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream5) + del primals_13 + return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2048 + primals_2 = 128 + primals_3 = 5245440 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_5 = 2048 + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_7 = 2048 + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:5', dtype=torch.int64) + primals_9 = 1 + primals_10 = 8 + primals_11 = 32 + primals_12 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_13 = rand_strided((8, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/aq/caqvrlb25w5an4txp3dstxcj6tqlcc4mprakf75e5sbtbuzd254g.py b/SpecForge-ext/cache/compiled_kernels/aq/caqvrlb25w5an4txp3dstxcj6tqlcc4mprakf75e5sbtbuzd254g.py new file mode 100644 index 0000000000000000000000000000000000000000..75fd91f55c9e03f73e7eb74c4de0eab69571dd6f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/aq/caqvrlb25w5an4txp3dstxcj6tqlcc4mprakf75e5sbtbuzd254g.py @@ -0,0 +1,711 @@ +# AOT ID: ['13_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/u4/cu4la2snj6taof6hjdgfl2ludclb5rxnhhncr47hr5tawo3djlhk.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:7" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[2, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:7" = PlaceHolder[target=primals_9] +# %primals_17 : Tensor "i32[2, 1, s94][s94, s94, 1]cuda:7" = PlaceHolder[target=primals_17] +# %primals_20 : Tensor "i32[2, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:7" = PlaceHolder[target=primals_20] +# %primals_14 : Tensor "i64[2][1]cuda:7" = PlaceHolder[target=primals_14] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + s75 = primals_15 + s94 = primals_16 + s28 = primals_18 + s4 = primals_19 + s56 = primals_21 + s84 = primals_23 + s53 = primals_24 + s100 = primals_26 + s6 = primals_28 + s10 = primals_29 + assert_size_stride(primals_2, (2, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_6, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_9, (2, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (2, ), (1, )) + assert_size_stride(primals_17, (2, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_20, (2, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_22, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_25, (2, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_27, (2, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_30, (2, 1, s6, s10), (s10*s6, s10*s6, s10, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream7 = get_raw_stream(7) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_17, primals_20, primals_14, buf2, s37, s0, s99, s22, s72, s75, (127 + s37) // 128, 2, 32, stream=stream7) + del buf1 + return (buf2, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, buf2, buf0, s37, s0, s75, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 1904 + primals_2 = rand_strided((2, 32, 1904, 128), (7798784, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_3 = 1904 + primals_4 = rand_strided((2, 8, 1904, 128), (1949696, 243712, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_5 = 1904 + primals_6 = rand_strided((2, 8, 1904, 128), (1949696, 243712, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_7 = 15 + primals_8 = 15 + primals_9 = rand_strided((2, 1, 15, 15), (225, 225, 15, 1), device='cuda:7', dtype=torch.int32) + primals_10 = 1904 + primals_11 = 1904 + primals_12 = 15 + primals_13 = rand_strided((2, 1, 15), (15, 15, 1), device='cuda:7', dtype=torch.int32) + primals_14 = rand_strided((2, ), (1, ), device='cuda:7', dtype=torch.int64) + primals_15 = 1904 + primals_16 = 15 + primals_17 = rand_strided((2, 1, 15), (15, 15, 1), device='cuda:7', dtype=torch.int32) + primals_18 = 15 + primals_19 = 15 + primals_20 = rand_strided((2, 1, 15, 15), (225, 225, 15, 1), device='cuda:7', dtype=torch.int32) + primals_21 = 15 + primals_22 = rand_strided((2, 1, 15), (15, 15, 1), device='cuda:7', dtype=torch.int32) + primals_23 = 15 + primals_24 = 15 + primals_25 = rand_strided((2, 1, 15, 15), (225, 225, 15, 1), device='cuda:7', dtype=torch.int32) + primals_26 = 15 + primals_27 = rand_strided((2, 1, 15), (15, 15, 1), device='cuda:7', dtype=torch.int32) + primals_28 = 15 + primals_29 = 15 + primals_30 = rand_strided((2, 1, 15, 15), (225, 225, 15, 1), device='cuda:7', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/at/2dfb5ffb77d217b8298333b84d6362971879c20614915aac57601c1f150ac07b.best_config b/SpecForge-ext/cache/compiled_kernels/at/2dfb5ffb77d217b8298333b84d6362971879c20614915aac57601c1f150ac07b.best_config new file mode 100644 index 0000000000000000000000000000000000000000..0102fea510b9bf77ab661e714dfc816c066dc0d8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/at/2dfb5ffb77d217b8298333b84d6362971879c20614915aac57601c1f150ac07b.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "IK5RT3JGLTF5PMMUH32NIWB2GXNU6R6CGIZSCRHU3I65YM226KDA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/at/cat6f3b7vbc3opxxrqwtgyrnap7msqfa5gw45bly56fm7xfzsng7.py b/SpecForge-ext/cache/compiled_kernels/at/cat6f3b7vbc3opxxrqwtgyrnap7msqfa5gw45bly56fm7xfzsng7.py new file mode 100644 index 0000000000000000000000000000000000000000..57b718fdf13315c0f38116911866361248616f8a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/at/cat6f3b7vbc3opxxrqwtgyrnap7msqfa5gw45bly56fm7xfzsng7.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/at/catnwworbo47zz5uux2qx6gtvq5zrkdmzm5qpt64msmr3cjlnoz5.py b/SpecForge-ext/cache/compiled_kernels/at/catnwworbo47zz5uux2qx6gtvq5zrkdmzm5qpt64msmr3cjlnoz5.py new file mode 100644 index 0000000000000000000000000000000000000000..4759ab7dec11697c15271245333d563f39192958 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/at/catnwworbo47zz5uux2qx6gtvq5zrkdmzm5qpt64msmr3cjlnoz5.py @@ -0,0 +1,675 @@ +# AOT ID: ['6_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/en/cenh5uz42ng4lj7xw7veh7qtahkm73nfwpjlgreomiruz4qp4l5j.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:0" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:0" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "i64[8][1]cuda:0" = PlaceHolder[target=primals_6] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12 = args + args.clear() + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (8, ), (1, )) + assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 8, 32, stream=stream0) + del buf1 + return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_6 = rand_strided((8, ), (1, ), device='cuda:0', dtype=torch.int64) + primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/av/cavp7xan77tfr7qytfkp6sjrgkd6hvruiaqfzkeibtl5rtagscng.py b/SpecForge-ext/cache/compiled_kernels/av/cavp7xan77tfr7qytfkp6sjrgkd6hvruiaqfzkeibtl5rtagscng.py new file mode 100644 index 0000000000000000000000000000000000000000..7811a40287b74be819c04afaaca1f6e74eccb7ed --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/av/cavp7xan77tfr7qytfkp6sjrgkd6hvruiaqfzkeibtl5rtagscng.py @@ -0,0 +1,99 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // ks0) % ks1) + x0 = (xindex % ks0) + x2 = xindex // ks4 + _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = ks2 + tmp2 = tmp0 < tmp1 + tmp3 = r0_3 + 128*x0 + tmp4 = ks3 + tmp5 = tmp3 < tmp4 + tmp6 = tmp2 & tmp5 + tmp7 = r0_4 + 128*x1 + tmp8 = r0_3 + 128*x0 + tmp9 = tmp7 >= tmp8 + tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0) + tmp11 = tmp8 < tmp10 + tmp12 = tmp7 < tmp10 + tmp13 = tmp11 & tmp12 + tmp14 = tmp9 & tmp13 + tmp15 = tl.full([1, 1], False, tl.int1) + tmp16 = tmp15 | tmp14 + tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK]) + tmp18 = tmp8 >= tmp17 + tmp19 = (tmp8 % tmp17) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp17 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tmp27 < tmp10 + tmp29 = tmp18 & tmp28 + tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp31 = (tmp30 % tmp17) + tmp32 = tmp31 != tmp20 + tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0 + tmp34 = tmp33 != tmp23 + tmp35 = tmp32 & tmp34 + tmp36 = tmp31 + tmp17 + tmp37 = tl.where(tmp35, tmp36, tmp31) + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp37 == tmp38 + tmp40 = tmp29 & tmp39 + tmp41 = tmp16 | tmp40 + tmp42 = tl.full(tmp41.shape, False, tmp41.dtype) + tmp43 = tl.where(tmp6, tmp41, tmp42) + tmp44 = tmp43.to(tl.int64) + tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) + tmp47 = _tmp46 + tmp45 + _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46) + tmp46 = tl.sum(_tmp46, 1)[:, None] + tmp48 = tl.full([1, 1], 0, tl.int64) + tmp49 = tmp46 > tmp48 + tmp50 = tl.full([1, 1], 16384, tl.int64) + tmp51 = tmp46 < tmp50 + tmp52 = tmp49 & tmp51 + tmp53 = tmp52.to(tl.int8) + tmp54 = tmp53.to(tl.int32) + tmp55 = tmp46 == tmp50 + tmp56 = tmp55.to(tl.int8) + tmp57 = tmp56.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp54, xmask) + tl.store(out_ptr2 + (x5), tmp57, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/bd/cbdpymknkquuerovirx6corahubfs5khfhys2add2b3c2zkuvlup.py b/SpecForge-ext/cache/compiled_kernels/bd/cbdpymknkquuerovirx6corahubfs5khfhys2add2b3c2zkuvlup.py new file mode 100644 index 0000000000000000000000000000000000000000..b80d6817a7fb0cb5036ae340f3f706547b03771e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bd/cbdpymknkquuerovirx6corahubfs5khfhys2add2b3c2zkuvlup.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/bi/8786fd641e91216a3bc7781055fbc9277e1637f9f319eaed8124e438ba94886f.best_config b/SpecForge-ext/cache/compiled_kernels/bi/8786fd641e91216a3bc7781055fbc9277e1637f9f319eaed8124e438ba94886f.best_config new file mode 100644 index 0000000000000000000000000000000000000000..480196fb4e04fbb2c822576a3fd81e7866bc88fc --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bi/8786fd641e91216a3bc7781055fbc9277e1637f9f319eaed8124e438ba94886f.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "E2MI47QNGZ2SJDA3U3EKHN7H3EYRAANF6T7N5SFT2CZJYNBAWCNQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/bi/cbigeynkmamirzra5ocdek4vfe3idnh2kr2bfscbxtiim3rq5df5.py b/SpecForge-ext/cache/compiled_kernels/bi/cbigeynkmamirzra5ocdek4vfe3idnh2kr2bfscbxtiim3rq5df5.py new file mode 100644 index 0000000000000000000000000000000000000000..2c8870d344b86481be19e28ffde9a4e5dca26032 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bi/cbigeynkmamirzra5ocdek4vfe3idnh2kr2bfscbxtiim3rq5df5.py @@ -0,0 +1,89 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1024, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/bi/cbiwa42zuoemjhwwkub6gypxcryfi2fbcigroxmaahfipc6cwcmf.py b/SpecForge-ext/cache/compiled_kernels/bi/cbiwa42zuoemjhwwkub6gypxcryfi2fbcigroxmaahfipc6cwcmf.py new file mode 100644 index 0000000000000000000000000000000000000000..66a6207f86ca03b8f56613ef1da4794edcd69361 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bi/cbiwa42zuoemjhwwkub6gypxcryfi2fbcigroxmaahfipc6cwcmf.py @@ -0,0 +1,50 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/bl/4f3b743b1b08a4d33a209a79d673559deb821150d3c9096a97a7b822aab5c6a2.best_config b/SpecForge-ext/cache/compiled_kernels/bl/4f3b743b1b08a4d33a209a79d673559deb821150d3c9096a97a7b822aab5c6a2.best_config new file mode 100644 index 0000000000000000000000000000000000000000..990be040d913054ee650201b25cf2c95af882efd --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bl/4f3b743b1b08a4d33a209a79d673559deb821150d3c9096a97a7b822aab5c6a2.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "R0_BLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "b70837e3723f218c7368cc2b49566dcd2bec3baf4c88b5e174a3f0822a6c86c0", "found_by_coordesc": false, "time_taken_ms": 142, "triton_cache_hash": "BZ2FPB5QIE7EHR6P7EPVPHR4HKS3YX3QQPIWQIT2R3EOJOAVWCGA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/bl/cbl5w3gbsoiwosnwknfarb55lklpfntpuz6q4jjui35yi2wdwepo.py b/SpecForge-ext/cache/compiled_kernels/bl/cbl5w3gbsoiwosnwknfarb55lklpfntpuz6q4jjui35yi2wdwepo.py new file mode 100644 index 0000000000000000000000000000000000000000..1d9af826b566753c7a25e23655aa34781ae0da2f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bl/cbl5w3gbsoiwosnwknfarb55lklpfntpuz6q4jjui35yi2wdwepo.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) diff --git a/SpecForge-ext/cache/compiled_kernels/bl/cblj2jzv4p4sj4ui3mue4irzzpu2bcs5wme4slb5xxnkukkqtl6u.py b/SpecForge-ext/cache/compiled_kernels/bl/cblj2jzv4p4sj4ui3mue4irzzpu2bcs5wme4slb5xxnkukkqtl6u.py new file mode 100644 index 0000000000000000000000000000000000000000..c0fd4b05f2244b159cca9a1f74dc81ae2584fd93 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bl/cblj2jzv4p4sj4ui3mue4irzzpu2bcs5wme4slb5xxnkukkqtl6u.py @@ -0,0 +1,682 @@ +# AOT ID: ['12_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/p7/cp7oi5evlluu4tzoolnivejb2h2wxctqdm2h4fyxttvr7dsyw3cu.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[8, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mw/cmwjkzn63wfw2g7ct26jjfz6huarnwl6j2bgtj62piorfrmee3xb.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_2, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_24 +# full_blocks => eq_45 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_24, view_7 +# suffix_mask => ge_3 +# Graph fragment: +# %arg2_1 : Tensor "i64[8][1]cuda:0" = PlaceHolder[target=arg2_1] +# %sum_1 : Tensor "i64[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 8*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:0" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %iota_2 : Tensor "i64[s12][1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %view : Tensor "i64[s12, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:0"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg1_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %ge_2 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[8][1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %index : Tensor "i64[8][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[8, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {}) +# %lt : Tensor "b8[8, s37][Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, %arg1_1]), kwargs = {}) +# %index_1 : Tensor "i64[8][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[8, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[8, s12][Max(1, s12), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[8, s12, 1][Max(1, s12), 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, %arg0_1, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[8, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[8, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[8, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_3 : Tensor "b8[s37][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %index_2 : Tensor "i64[8][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[8, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_3, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, %arg1_1]), kwargs = {}) +# %view_7 : Tensor "i64[s12, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %sub_24 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_24, %arg3_1), kwargs = {}) +# %eq_24 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[8, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_24), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[8, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[8, 1, s12, s37][Max(1, s12)*Max(1, s37), s12*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, %arg0_1, %arg1_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[8, 1, 128*(((s12 + 127)//128)), 128*(((s37 + 127)//128))][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_42, 0, %sub_44], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[8, 1, ((s12 + 127)//128), 128, ((s37 + 127)//128), 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [8, 1, %floordiv_3, 128, %floordiv_2, 128]), kwargs = {}) +# %permute : Tensor "b8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128), 128, 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_45 : Tensor "b8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_45, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // ks0) % ks1) + x0 = (xindex % ks0) + x2 = xindex // ks4 + _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = ks2 + tmp2 = tmp0 < tmp1 + tmp3 = r0_3 + 128*x0 + tmp4 = ks3 + tmp5 = tmp3 < tmp4 + tmp6 = tmp2 & tmp5 + tmp7 = r0_4 + 128*x1 + tmp8 = r0_3 + 128*x0 + tmp9 = tmp7 >= tmp8 + tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0) + tmp11 = tmp8 < tmp10 + tmp12 = tmp7 < tmp10 + tmp13 = tmp11 & tmp12 + tmp14 = tmp9 & tmp13 + tmp15 = tl.full([1, 1], False, tl.int1) + tmp16 = tmp15 | tmp14 + tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK]) + tmp18 = tmp8 >= tmp17 + tmp19 = (tmp8 % tmp17) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp17 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tmp27 < tmp10 + tmp29 = tmp18 & tmp28 + tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp31 = (tmp30 % tmp17) + tmp32 = tmp31 != tmp20 + tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0 + tmp34 = tmp33 != tmp23 + tmp35 = tmp32 & tmp34 + tmp36 = tmp31 + tmp17 + tmp37 = tl.where(tmp35, tmp36, tmp31) + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp37 == tmp38 + tmp40 = tmp29 & tmp39 + tmp41 = tmp16 | tmp40 + tmp42 = tl.full(tmp41.shape, False, tmp41.dtype) + tmp43 = tl.where(tmp6, tmp41, tmp42) + tmp44 = tmp43.to(tl.int64) + tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) + tmp47 = _tmp46 + tmp45 + _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46) + tmp46 = tl.sum(_tmp46, 1)[:, None] + tmp48 = tl.full([1, 1], 0, tl.int64) + tmp49 = tmp46 > tmp48 + tmp50 = tl.full([1, 1], 16384, tl.int64) + tmp51 = tmp46 < tmp50 + tmp52 = tmp49 & tmp51 + tmp53 = tmp52.to(tl.int8) + tmp54 = tmp53.to(tl.int32) + tmp55 = tmp46 == tmp50 + tmp56 = tmp55.to(tl.int8) + tmp57 = tmp56.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp54, xmask) + tl.store(out_ptr2 + (x5), tmp57, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/5w/c5w735qbviioww7vfjj36tk57xo254oei3wqkunaiekkjd5pfcph.py +# Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# child_3 => convert_element_type_3 +# num_blocks_in_row => sum_2 +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 8*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:0" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[8, 1, ((s12 + 127)//128)][((s12 + 127)//128), 8*(((s12 + 127)//128)), 1]cuda:0" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[8, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[8, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# return %sum_2,%convert_element_type_3 +triton_red_fused__to_copy_sum_2 = async_compile.triton('triton_red_fused__to_copy_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/kb/ckb32oxhck35jxljiwtypjfxvau7giullotueug7rwhjsoqq3g2o.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %getitem_1 : Tensor "i64[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 8*Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[8, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:0" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[8, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:0" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[8, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %iota_7 : Tensor "i64[8][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[((s12 + 127)//128)][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_3,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:0, requires_grad: False}) +# %unsqueeze : Tensor "i32[((s12 + 127)//128), 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_2,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:0, requires_grad: False}) +# %unsqueeze_1 : Tensor "i32[8, 1, ((s12 + 127)//128), 1][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_2,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0}) +# %where : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %index_put : Tensor "i32[8, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %convert_element_type_4,%buf13 +triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3 = async_compile.triton('triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 = tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) + tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/jz/cjzemxbhxau3m7gj3akedavekbr2ty5epn3onr4dpcg6lxfbaehb.py +# Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# Graph fragment: +# %buf13 : Tensor "i32[8, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:0" = PlaceHolder[target=buf13] +# %slice_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_4 +triton_poi_fused_clone_slice_4 = async_compile.triton('triton_poi_fused_clone_slice_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wh/cwhgu2bzfksecbm4int3okkh3ph54kdms6zfad7ngpfzq6h2owww.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# num_blocks_in_row_2 => sum_4 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %clone_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 1, ((s37 + 127)//128), 1]cuda:0" = PlaceHolder[target=clone_4] +# %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][((s37 + 127)//128), 8*(((s37 + 127)//128)), 1]cuda:0" = PlaceHolder[target=sum_4] +# %slice_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[8, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %sum_4,%convert_element_type_8 +triton_red_fused__to_copy_clone_slice_sum_transpose_5 = async_compile.triton('triton_red_fused__to_copy_clone_slice_sum_transpose_5', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/at/cat6f3b7vbc3opxxrqwtgyrnap7msqfa5gw45bly56fm7xfzsng7.py +# Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] +# Source node to ATen node mapping: +# q_indices => clone_6, convert_element_type_9 +# Graph fragment: +# %getitem_5 : Tensor "i64[8, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:0" = PlaceHolder[target=getitem_5] +# %convert_element_type_9 : Tensor "i32[8, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[8, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_6 +triton_poi_fused__to_copy_6 = async_compile.triton('triton_poi_fused__to_copy_6', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s12 = arg0_1 + s37 = arg1_1 + s21 = arg3_1 + assert_size_stride(arg2_1, (8, ), (1, )) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf12 = empty_strided_cuda((8, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 8*((127 + s12) // 128) + 8*((127 + s12) // 128)*((127 + s37) // 128) + stream0 = get_raw_stream(0) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream0) + buf21 = empty_strided_cuda((8, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 8*((127 + s12) // 128) + 8*((127 + s12) // 128)*((127 + s37) // 128) + stream0 = get_raw_stream(0) + triton_poi_fused_new_zeros_0.run(buf21, triton_poi_fused_new_zeros_0_xnumel, stream=stream0) + ps0 = (127 + s37) // 128 + ps1 = (127 + s12) // 128 + ps2 = ((127 + s12) // 128)*((127 + s37) // 128) + buf1 = empty_strided_cuda((8, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 8*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((8, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 8*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream0 = get_raw_stream(0) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg2_1, buf1, buf5, ps0, ps1, s12, s37, ps2, s21, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream0) + del arg2_1 + buf10 = empty_strided_cuda((8, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 8*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream0 = get_raw_stream(0) + triton_red_fused__to_copy_sum_2.run(buf1, buf10, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream0) + buf19 = empty_strided_cuda((8, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row_1, child_7], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 8*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream0 = get_raw_stream(0) + triton_red_fused__to_copy_sum_2.run(buf5, buf19, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream0) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + del buf1 + buf4 = buf2[1] + assert_size_stride(buf4, (8, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 8*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf11 = empty_strided_cuda((8, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream0 = get_raw_stream(0) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf4, buf10, buf11, buf12, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream0) + del buf4 + buf14 = empty_strided_cuda((8, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream0 = get_raw_stream(0) + triton_poi_fused_clone_slice_4.run(buf12, buf14, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream0) + del buf12 + buf32 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 8*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream0 = get_raw_stream(0) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf14, buf32, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream0) + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + del buf5 + buf8 = buf6[1] + assert_size_stride(buf8, (8, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 8*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf20 = empty_strided_cuda((8, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream0 = get_raw_stream(0) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf8, buf19, buf20, buf21, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream0) + del buf8 + buf23 = empty_strided_cuda((8, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream0 = get_raw_stream(0) + triton_poi_fused_clone_slice_4.run(buf21, buf23, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream0) + del buf21 + buf29 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 8*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream0 = get_raw_stream(0) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf23, buf29, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream0) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf15 = torch.ops.aten.sort.stable(reinterpret_tensor(buf14, (8, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf14 + buf17 = buf15[1] + assert_size_stride(buf17, (8, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf17, 16, 'torch.ops.aten.sort.stable') + del buf15 + buf30 = empty_strided_cuda((8, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream0 = get_raw_stream(0) + triton_poi_fused__to_copy_6.run(buf17, buf30, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream0) + del buf17 + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf24 = torch.ops.aten.sort.stable(reinterpret_tensor(buf23, (8, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf23 + buf26 = buf24[1] + assert_size_stride(buf26, (8, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf26, 16, 'torch.ops.aten.sort.stable') + del buf24 + buf27 = empty_strided_cuda((8, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [full_q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream0 = get_raw_stream(0) + triton_poi_fused__to_copy_6.run(buf26, buf27, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream0) + del buf26 + return (buf27, buf29, buf30, buf32, buf20, buf19, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2009 + arg1_1 = 2009 + arg2_1 = rand_strided((8, ), (1, ), device='cuda:0', dtype=torch.int64) + arg3_1 = 2009 + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/bn/6c249a631d8e00fcdeac92967dc92fbca5a0de3185904916fe1c6f7d7c53eb18.best_config b/SpecForge-ext/cache/compiled_kernels/bn/6c249a631d8e00fcdeac92967dc92fbca5a0de3185904916fe1c6f7d7c53eb18.best_config new file mode 100644 index 0000000000000000000000000000000000000000..39aa06f1122c6eb2904338d2578102fd0e126a89 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bn/6c249a631d8e00fcdeac92967dc92fbca5a0de3185904916fe1c6f7d7c53eb18.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/bn/cbnlhjvdilhpyvbq45tudsctw5loi5h3hcwqzatfbvx3vqngs73c.py b/SpecForge-ext/cache/compiled_kernels/bn/cbnlhjvdilhpyvbq45tudsctw5loi5h3hcwqzatfbvx3vqngs73c.py new file mode 100644 index 0000000000000000000000000000000000000000..c770e2fea4461765223fe9f8a1e8785c1bd2bf31 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bn/cbnlhjvdilhpyvbq45tudsctw5loi5h3hcwqzatfbvx3vqngs73c.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 256, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/bq/cbqbhzf3233oxueuh2qyo5rpsej4in3zlepofypgyrmhyyfbziwx.py b/SpecForge-ext/cache/compiled_kernels/bq/cbqbhzf3233oxueuh2qyo5rpsej4in3zlepofypgyrmhyyfbziwx.py new file mode 100644 index 0000000000000000000000000000000000000000..8fe2bbc5a1c4199e1df1846daf6f47da4f1d2647 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bq/cbqbhzf3233oxueuh2qyo5rpsej4in3zlepofypgyrmhyyfbziwx.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/by/cbyf6tnxdurjeuzcpf5qwfctjpqdrbul2dx4g37h6xtdepaj3zuo.py b/SpecForge-ext/cache/compiled_kernels/by/cbyf6tnxdurjeuzcpf5qwfctjpqdrbul2dx4g37h6xtdepaj3zuo.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a3a419e7ab6473198d1fde20da57634229041f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/by/cbyf6tnxdurjeuzcpf5qwfctjpqdrbul2dx4g37h6xtdepaj3zuo.py @@ -0,0 +1,99 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // ks0) % ks1) + x0 = (xindex % ks0) + x2 = xindex // ks4 + _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = ks2 + tmp2 = tmp0 < tmp1 + tmp3 = r0_3 + 128*x0 + tmp4 = ks3 + tmp5 = tmp3 < tmp4 + tmp6 = tmp2 & tmp5 + tmp7 = r0_4 + 128*x1 + tmp8 = r0_3 + 128*x0 + tmp9 = tmp7 >= tmp8 + tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0) + tmp11 = tmp8 < tmp10 + tmp12 = tmp7 < tmp10 + tmp13 = tmp11 & tmp12 + tmp14 = tmp9 & tmp13 + tmp15 = tl.full([1, 1], False, tl.int1) + tmp16 = tmp15 | tmp14 + tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK]) + tmp18 = tmp8 >= tmp17 + tmp19 = (tmp8 % tmp17) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp17 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tmp27 < tmp10 + tmp29 = tmp18 & tmp28 + tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp31 = (tmp30 % tmp17) + tmp32 = tmp31 != tmp20 + tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0 + tmp34 = tmp33 != tmp23 + tmp35 = tmp32 & tmp34 + tmp36 = tmp31 + tmp17 + tmp37 = tl.where(tmp35, tmp36, tmp31) + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp37 == tmp38 + tmp40 = tmp29 & tmp39 + tmp41 = tmp16 | tmp40 + tmp42 = tl.full(tmp41.shape, False, tmp41.dtype) + tmp43 = tl.where(tmp6, tmp41, tmp42) + tmp44 = tmp43.to(tl.int64) + tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) + tmp47 = _tmp46 + tmp45 + _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46) + tmp46 = tl.sum(_tmp46, 1)[:, None] + tmp48 = tl.full([1, 1], 0, tl.int64) + tmp49 = tmp46 > tmp48 + tmp50 = tl.full([1, 1], 16384, tl.int64) + tmp51 = tmp46 < tmp50 + tmp52 = tmp49 & tmp51 + tmp53 = tmp52.to(tl.int8) + tmp54 = tmp53.to(tl.int32) + tmp55 = tmp46 == tmp50 + tmp56 = tmp55.to(tl.int8) + tmp57 = tmp56.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp54, xmask) + tl.store(out_ptr2 + (x5), tmp57, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/c3/59c9853fb2c80186a32ab90888a18bc19d8d4be350d304e62760e7bdf3f09c29.best_config b/SpecForge-ext/cache/compiled_kernels/c3/59c9853fb2c80186a32ab90888a18bc19d8d4be350d304e62760e7bdf3f09c29.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c1c51c5048e176f0cf0b0d2646bd98c4186a3cba --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/c3/59c9853fb2c80186a32ab90888a18bc19d8d4be350d304e62760e7bdf3f09c29.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "R0_BLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 60, "triton_cache_hash": "EGDJYO36DUYGK3UQBUH6S7RMVKF77GGHWVMFFZR5R4TDMIZ4YVJA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/c3/cc3guwnwiox3yzzjtaquh6k4sm6nn4lcmkep56rop3grqr44xorh.py b/SpecForge-ext/cache/compiled_kernels/c3/cc3guwnwiox3yzzjtaquh6k4sm6nn4lcmkep56rop3grqr44xorh.py new file mode 100644 index 0000000000000000000000000000000000000000..21874eda12f1c0cd74f919ec866addc0d45c23ca --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/c3/cc3guwnwiox3yzzjtaquh6k4sm6nn4lcmkep56rop3grqr44xorh.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) diff --git a/SpecForge-ext/cache/compiled_kernels/c5/cc5dyv2gy7kqwwgof22mbw3houj3mwz3mpm5wwkls5nzlyig75gr.py b/SpecForge-ext/cache/compiled_kernels/c5/cc5dyv2gy7kqwwgof22mbw3houj3mwz3mpm5wwkls5nzlyig75gr.py new file mode 100644 index 0000000000000000000000000000000000000000..00d3f3d1a498615b0ffe1dc21b915218533f31a3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/c5/cc5dyv2gy7kqwwgof22mbw3houj3mwz3mpm5wwkls5nzlyig75gr.py @@ -0,0 +1,63 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) diff --git a/SpecForge-ext/cache/compiled_kernels/c5/cc5nxeacvesgy4tcd55ftp6q6lwefzljen5getfvx2iwvse24nc3.py b/SpecForge-ext/cache/compiled_kernels/c5/cc5nxeacvesgy4tcd55ftp6q6lwefzljen5getfvx2iwvse24nc3.py new file mode 100644 index 0000000000000000000000000000000000000000..361a0c4b47d0acf78d1aaf4114cb08b4e8707d40 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/c5/cc5nxeacvesgy4tcd55ftp6q6lwefzljen5getfvx2iwvse24nc3.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/df/22a765d91676f959d9b3685b36d65a55ec25fa54238d5b7f1ad670a67b2ad8b4.best_config b/SpecForge-ext/cache/compiled_kernels/df/22a765d91676f959d9b3685b36d65a55ec25fa54238d5b7f1ad670a67b2ad8b4.best_config new file mode 100644 index 0000000000000000000000000000000000000000..37707241555f35a01f7e4a693e0cda27ae37aab0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/df/22a765d91676f959d9b3685b36d65a55ec25fa54238d5b7f1ad670a67b2ad8b4.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 22, "triton_cache_hash": "XRR2QXTZQK4DSBTDJUTNXO6FEFXI2IIRKSC5GYSBWLTL56SKI4WA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/df/cdfb6cgenzsju5cqvy4244xh4xidniyeznvkubvdg2mg6d5oc6xt.py b/SpecForge-ext/cache/compiled_kernels/df/cdfb6cgenzsju5cqvy4244xh4xidniyeznvkubvdg2mg6d5oc6xt.py new file mode 100644 index 0000000000000000000000000000000000000000..28acc3fcd780d0a9d09c77841b0a4c83aae0b668 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/df/cdfb6cgenzsju5cqvy4244xh4xidniyeznvkubvdg2mg6d5oc6xt.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/dh/cdhkxlmnw3w37dquqijlbxyhjukeitqpvuq2bvripdw6e7kvh63r.py b/SpecForge-ext/cache/compiled_kernels/dh/cdhkxlmnw3w37dquqijlbxyhjukeitqpvuq2bvripdw6e7kvh63r.py new file mode 100644 index 0000000000000000000000000000000000000000..b9995c25534b4b763f2489caf941aaaf9419dc40 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/dh/cdhkxlmnw3w37dquqijlbxyhjukeitqpvuq2bvripdw6e7kvh63r.py @@ -0,0 +1,309 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/vy/cvyoqg7jzeadarrgggxhs2djmxugxyvjhim72vfstqc4io4qsecj.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat => cat +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# mul => mul_24 +# mul_1 => mul_45 +# neg => neg +# q_embed => add_54 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1 => slice_1 +# x2 => slice_2 +# Graph fragment: +# %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:3" = PlaceHolder[target=primals_12] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:3" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {}) +# %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {}) +# %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {}) +# %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {}) +# %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {}) +# %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {}) +# return %add_54 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wf/cwf6bkjfjzdctzdez7j3aj4ebefwcsqlz4gci5drec5pqysdr7fn.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat_1 => cat_1 +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# k_embed => add_90 +# mul_2 => mul_54 +# mul_3 => mul_75 +# neg_1 => neg_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1_1 => slice_3 +# x2_1 => slice_4 +# Graph fragment: +# %primals_14 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:3" = PlaceHolder[target=primals_14] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:3" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_54 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_14, %unsqueeze), kwargs = {}) +# %slice_3 : Tensor "bf16[s48, s25, s9, (s24//2)][s24*s25*s9, s24, s24*s25, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_14, 3, 0, %floordiv), kwargs = {}) +# %slice_4 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s24*s25*s9, s24, s24*s25, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_14, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg_1 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s25*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s25*Max(1, s24 - ((s24//2))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {}) +# %cat_1 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {}) +# %mul_75 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {}) +# %add_90 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {}) +# return %add_90 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14 = args + args.clear() + s92 = primals_1 + s24 = primals_2 + s96 = primals_3 + s79 = primals_5 + s9 = primals_7 + s38 = primals_9 + s48 = primals_10 + s34 = primals_11 + s25 = primals_13 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1)) + assert_size_stride(primals_14, (s48, s25, s9, s24), (s24*s25*s9, s24, s24*s25, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + ps0 = s24*s34 + buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9 + stream3 = get_raw_stream(3) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream3) + del primals_12 + ps1 = s24*s25 + buf1 = empty_strided_cuda((s48, s25, s9, s24), (s24*s25*s9, s24, s24*s25, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s25*s48*s9 + stream3 = get_raw_stream(3) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_14, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream3) + del primals_14 + return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s25, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2048 + primals_2 = 128 + primals_3 = 5245440 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_5 = 2048 + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_7 = 2048 + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:3', dtype=torch.int64) + primals_9 = 1 + primals_10 = 2 + primals_11 = 32 + primals_12 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_13 = 8 + primals_14 = rand_strided((2, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/dk/cdkxpgcrmzd7osnpkgshsznwo3l2rl6sh5w7conrraiadtvt4dua.py b/SpecForge-ext/cache/compiled_kernels/dk/cdkxpgcrmzd7osnpkgshsznwo3l2rl6sh5w7conrraiadtvt4dua.py new file mode 100644 index 0000000000000000000000000000000000000000..53296e07f825b39b5b0fa1cc96b92dc6d966d279 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/dk/cdkxpgcrmzd7osnpkgshsznwo3l2rl6sh5w7conrraiadtvt4dua.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/dn/cdnzd5cdpynnx7cj5q3fipu4rishllkvp32nawgi2lth6637zisz.py b/SpecForge-ext/cache/compiled_kernels/dn/cdnzd5cdpynnx7cj5q3fipu4rishllkvp32nawgi2lth6637zisz.py new file mode 100644 index 0000000000000000000000000000000000000000..35e6c946cf7684e920a1ce92289d2cfcf1bda3db --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/dn/cdnzd5cdpynnx7cj5q3fipu4rishllkvp32nawgi2lth6637zisz.py @@ -0,0 +1,711 @@ +# AOT ID: ['13_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/c4/cc44tmaxtaxohkbf52w5omwmrxhrmn6iuplipagv7rlnxaz6dkey.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[8, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[8, 1, s99][s99, s99, 1]cuda:7" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[8, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:7" = PlaceHolder[target=primals_9] +# %primals_17 : Tensor "i32[8, 1, s94][s94, s94, 1]cuda:7" = PlaceHolder[target=primals_17] +# %primals_20 : Tensor "i32[8, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:7" = PlaceHolder[target=primals_20] +# %primals_14 : Tensor "i64[8][1]cuda:7" = PlaceHolder[target=primals_14] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + s75 = primals_15 + s94 = primals_16 + s28 = primals_18 + s4 = primals_19 + s56 = primals_21 + s84 = primals_23 + s53 = primals_24 + s100 = primals_26 + s6 = primals_28 + s10 = primals_29 + assert_size_stride(primals_2, (8, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_6, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_9, (8, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (8, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (8, ), (1, )) + assert_size_stride(primals_17, (8, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_20, (8, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_22, (8, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_25, (8, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_27, (8, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_30, (8, 1, s6, s10), (s10*s6, s10*s6, s10, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((8, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((8, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((8, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream7 = get_raw_stream(7) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_17, primals_20, primals_14, buf2, s37, s0, s99, s22, s72, s75, (127 + s37) // 128, 8, 32, stream=stream7) + del buf1 + return (buf2, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, buf2, buf0, s37, s0, s75, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2025 + primals_2 = rand_strided((8, 32, 2025, 128), (8294400, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_3 = 2025 + primals_4 = rand_strided((8, 8, 2025, 128), (2073600, 259200, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_5 = 2025 + primals_6 = rand_strided((8, 8, 2025, 128), (2073600, 259200, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_7 = 16 + primals_8 = 16 + primals_9 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_10 = 2025 + primals_11 = 2025 + primals_12 = 16 + primals_13 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_14 = rand_strided((8, ), (1, ), device='cuda:7', dtype=torch.int64) + primals_15 = 2025 + primals_16 = 16 + primals_17 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_18 = 16 + primals_19 = 16 + primals_20 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_21 = 16 + primals_22 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_23 = 16 + primals_24 = 16 + primals_25 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_26 = 16 + primals_27 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_28 = 16 + primals_29 = 16 + primals_30 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/e2/ce2ywn5toqisdm3havxrhdb5mv52umindbb5uqyq5tsuulfpdsvn.py b/SpecForge-ext/cache/compiled_kernels/e2/ce2ywn5toqisdm3havxrhdb5mv52umindbb5uqyq5tsuulfpdsvn.py new file mode 100644 index 0000000000000000000000000000000000000000..1782b5bae318d5765abe6a93ba109f053601630a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/e2/ce2ywn5toqisdm3havxrhdb5mv52umindbb5uqyq5tsuulfpdsvn.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/e5/338255b027f64b063c1dd4141ea4528d8c085950d66acacaef657908130c477a.best_config b/SpecForge-ext/cache/compiled_kernels/e5/338255b027f64b063c1dd4141ea4528d8c085950d66acacaef657908130c477a.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b06045019d512d071043c60a2787e7e25be43ca7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/e5/338255b027f64b063c1dd4141ea4528d8c085950d66acacaef657908130c477a.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "8c03dc2e05d158372838fe4d32248dfba74b467c7576f6e1d3eb472c41b37c80", "found_by_coordesc": false, "time_taken_ms": 198, "triton_cache_hash": "YHAVDQXMEVV7S4RZ3RZ2CHWHFBN2O3IAF5U3VLSP72AQHADF3BWQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/e5/ce5wvdfswd2gz4ndybxwt5ldwq4inxmypxads2mexg6jvncych4e.py b/SpecForge-ext/cache/compiled_kernels/e5/ce5wvdfswd2gz4ndybxwt5ldwq4inxmypxads2mexg6jvncych4e.py new file mode 100644 index 0000000000000000000000000000000000000000..69ce0a79abc74d7b132dc0ed3d1be3b0561c4e32 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/e5/ce5wvdfswd2gz4ndybxwt5ldwq4inxmypxads2mexg6jvncych4e.py @@ -0,0 +1,168 @@ +# AOT ID: ['10_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mb/cmbhgz4c2hwbce6pchwlcnkxfrh55hxi5c4dp2sn4lys5xivdvad.py +# Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] +# Source node to ATen node mapping: +# getitem_1 => unsqueeze +# position_mask => mul_2 +# target_mask => index +# target_mask_1 => convert_element_type +# target_max_token => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s14, 151936][151936*s14, 151936, 1]cuda:1" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[2, s14][s14, 1]cuda:1" = PlaceHolder[target=argmax] +# %arg2_1 : Tensor "b8[151936][1]cuda:1" = PlaceHolder[target=arg2_1] +# %arg3_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:1" = PlaceHolder[target=arg3_1] +# %argmax : Tensor "i64[2, s14][s14, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# %index : Tensor "b8[2, s14][s14, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%argmax]), kwargs = {}) +# %unsqueeze : Tensor "b8[2, s14, 1][s14, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {}) +# %convert_element_type : Tensor "i32[2, s14, 1][s14, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {}) +# %mul_2 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg3_1), kwargs = {}) +# return %argmax,%mul_2 +triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s24 = arg0_1 + arg1_1_size = arg1_1.size() + s14 = arg1_1_size[1] + assert_size_stride(arg1_1, (2, s14, 151936), (151936*s14, 151936, 1)) + assert_size_stride(arg2_1, (151936, ), (1, )) + assert_size_stride(arg3_1, (2, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((2, s14), (s14, 1), torch.int64) + buf1 = reinterpret_tensor(buf0, (2, s14, 1), (s14, 1, 1), 0); del buf0 # reuse + # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel = 2*s14 + stream1 = get_raw_stream(1) + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg1_1, arg2_1, arg3_1, triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel, 151936, stream=stream1) + del arg1_1 + del arg2_1 + del arg3_1 + return (buf1, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1041 + arg1_1 = rand_strided((2, 1041, 151936), (158165376, 151936, 1), device='cuda:1', dtype=torch.bfloat16) + arg2_1 = rand_strided((151936, ), (1, ), device='cuda:1', dtype=torch.bool) + arg3_1 = rand_strided((2, 1041, 1), (1041, 1, 1), device='cuda:1', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/e7/4d7ba20defa03542a96d272a3e5f80d808a74a250a081e3ec006b26385f2b14d.best_config b/SpecForge-ext/cache/compiled_kernels/e7/4d7ba20defa03542a96d272a3e5f80d808a74a250a081e3ec006b26385f2b14d.best_config new file mode 100644 index 0000000000000000000000000000000000000000..7921a12b007ca46a00e959ad115401adf0bd4471 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/e7/4d7ba20defa03542a96d272a3e5f80d808a74a250a081e3ec006b26385f2b14d.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "B46RWD5PEMKEQR7EBR6IG3BGTK4P7CWBVNOODNZQX5NAVXXVIH2A"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/e7/57bfd05f590389e28706411db3f110acfcb7ffd86f9ba40e21aeb6f925aebe0d.best_config b/SpecForge-ext/cache/compiled_kernels/e7/57bfd05f590389e28706411db3f110acfcb7ffd86f9ba40e21aeb6f925aebe0d.best_config new file mode 100644 index 0000000000000000000000000000000000000000..4fd1a4011d3f76e1e93d795a89d6bb6cdfb005a5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/e7/57bfd05f590389e28706411db3f110acfcb7ffd86f9ba40e21aeb6f925aebe0d.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 57, "triton_cache_hash": "XRR2QXTZQK4DSBTDJUTNXO6FEFXI2IIRKSC5GYSBWLTL56SKI4WA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/e7/ce7dupuihxxd773fvj7xxin22a343bj2cfpotahvltjkgeywvuif.py b/SpecForge-ext/cache/compiled_kernels/e7/ce7dupuihxxd773fvj7xxin22a343bj2cfpotahvltjkgeywvuif.py new file mode 100644 index 0000000000000000000000000000000000000000..cf6f123460f6ede2e3c91379c321e0c8b3c4e2ae --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/e7/ce7dupuihxxd773fvj7xxin22a343bj2cfpotahvltjkgeywvuif.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/e7/ce7e2b6aql2oficfu7lx6lr34nheb3k47fqvth4h6lbmsbvd64bk.py b/SpecForge-ext/cache/compiled_kernels/e7/ce7e2b6aql2oficfu7lx6lr34nheb3k47fqvth4h6lbmsbvd64bk.py new file mode 100644 index 0000000000000000000000000000000000000000..f45c707ada3e9be141f680941c0d9b77017ee2cf --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/e7/ce7e2b6aql2oficfu7lx6lr34nheb3k47fqvth4h6lbmsbvd64bk.py @@ -0,0 +1,1051 @@ +# AOT ID: ['6_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/c4/cc4r2l3x4dfli5iih5dji2abfxoclfozqdaqfbdxtcf6lqfpqwdo.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[8, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/cj/ccjdrw2g7vv7r7tckxv2il2dqbj2r4raz2ikafit5qmkqkg7j5jj.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:3" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:3" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:3" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_9] +# %primals_10 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=primals_8] +# %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=primals_12] +# %primals_6 : Tensor "i64[8][1]cuda:3" = PlaceHolder[target=primals_6] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1 = args + args.clear() + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (8, ), (1, )) + assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(getitem, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (8, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (8, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream3 = get_raw_stream(3) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 524288, 128, stream=stream3) + del getitem + buf3 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((8, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((8, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream3 = get_raw_stream(3) + triton_tem_fused_zeros_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_9, primals_10, primals_7, primals_8, primals_11, primals_12, primals_6, buf5, 80, 8, 8, stream=stream3) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_12 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_6 = rand_strided((8, ), (1, ), device='cuda:3', dtype=torch.int64) + primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + getitem = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + getitem_1 = rand_strided((8, 32, 2048), (65536, 2048, 1), device='cuda:3', dtype=torch.float32) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/e7/ce7i2tvzt4y76dl36jhidy7pbl6vcnmlxbjumk4msd6tfgpokqta.py b/SpecForge-ext/cache/compiled_kernels/e7/ce7i2tvzt4y76dl36jhidy7pbl6vcnmlxbjumk4msd6tfgpokqta.py new file mode 100644 index 0000000000000000000000000000000000000000..8e930ab5283c413deb0001c142af7ece1c8db091 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/e7/ce7i2tvzt4y76dl36jhidy7pbl6vcnmlxbjumk4msd6tfgpokqta.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 8192}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/en/cenh5uz42ng4lj7xw7veh7qtahkm73nfwpjlgreomiruz4qp4l5j.py b/SpecForge-ext/cache/compiled_kernels/en/cenh5uz42ng4lj7xw7veh7qtahkm73nfwpjlgreomiruz4qp4l5j.py new file mode 100644 index 0000000000000000000000000000000000000000..b6d3931b03da428df61cfc3409dfc8e46e7a940d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/en/cenh5uz42ng4lj7xw7veh7qtahkm73nfwpjlgreomiruz4qp4l5j.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/es/ces6igirh7ild5cxrl3jkv5ib25midu6s5yyh2tqsvbm3cwwomwg.py b/SpecForge-ext/cache/compiled_kernels/es/ces6igirh7ild5cxrl3jkv5ib25midu6s5yyh2tqsvbm3cwwomwg.py new file mode 100644 index 0000000000000000000000000000000000000000..2f2e9a823426b9d6b04017400bebe774f162c893 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/es/ces6igirh7ild5cxrl3jkv5ib25midu6s5yyh2tqsvbm3cwwomwg.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/es/cesbynbccl5kxgzk5frbtjtz7wks36wdnlmmdjvad5mpaslioisz.py b/SpecForge-ext/cache/compiled_kernels/es/cesbynbccl5kxgzk5frbtjtz7wks36wdnlmmdjvad5mpaslioisz.py new file mode 100644 index 0000000000000000000000000000000000000000..eb8391dca4ce59b4a76b7e9badafcb144ef80d58 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/es/cesbynbccl5kxgzk5frbtjtz7wks36wdnlmmdjvad5mpaslioisz.py @@ -0,0 +1,40 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 = tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) + tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ev/cevyod6oo6er6cp6u7pfouxiiixzspqvpc3i7tolk7lj6p5y6z4f.py b/SpecForge-ext/cache/compiled_kernels/ev/cevyod6oo6er6cp6u7pfouxiiixzspqvpc3i7tolk7lj6p5y6z4f.py new file mode 100644 index 0000000000000000000000000000000000000000..498f81c8d4af82e5597064731bf97ae1a9c7e1c3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ev/cevyod6oo6er6cp6u7pfouxiiixzspqvpc3i7tolk7lj6p5y6z4f.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/f7/cf7fnr4ndtni3aqqgi66t2b7kqmt5fc2ap2if42cub6ljd7c5z7p.py b/SpecForge-ext/cache/compiled_kernels/f7/cf7fnr4ndtni3aqqgi66t2b7kqmt5fc2ap2if42cub6ljd7c5z7p.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5ab52bfadcc433b2db0210788653c718dbc8ab --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/f7/cf7fnr4ndtni3aqqgi66t2b7kqmt5fc2ap2if42cub6ljd7c5z7p.py @@ -0,0 +1,354 @@ +# AOT ID: ['15_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uv/cuv3vrtusq2q2nsfbjisxnso2yv7wfwd5g5wzneucl6wsg7qdu22.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s3, 32000][32000*s3, 32000, 1]cuda:5" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ry/crybdkaospobmiqlxc56pxoib5eehua75mh3od3mgjz4754h54wu.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg4_1 : Tensor "f32[2, s3, 32000][s71, 32000, 1]cuda:5" = PlaceHolder[target=arg4_1] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg4_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/t6/ct65mho34hb6uiko5rqcirrhdoklxtx3wge4z7y6klsukdtxu23g.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq_2 +# mul => mul_3 +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:5" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:5" = PlaceHolder[target=argmax_1] +# %arg5_1 : Tensor "i64[2, s3, 1][s3, 1, 1]cuda:5" = PlaceHolder[target=arg5_1] +# %eq_2 : Tensor "b8[2, s3][s3, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[2, s3][s3, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg5_1, -1), kwargs = {}) +# %mul_3 : Tensor "i64[2, s3][s3, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_3,), kwargs = {}) +# return %sum_1 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zn/czno5j5znzvn7ffbpdidli27z42gguatm457hdcxachttejmohyn.py +# Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %arg7_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:5" = PlaceHolder[target=arg7_1] +# %sum_1 : Tensor "i64[][]cuda:5" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:5" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg7_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_2,%div +triton_red_fused_clamp_min_div_sum_3 = async_compile.triton('triton_red_fused_clamp_min_div_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tmp4 = tl.load(in_ptr1 + (0)) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1]) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp2.to(tl.float32) + tmp8 = 1e-06 + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = (tmp6 / tmp9) + tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1 = args + args.clear() + s3 = arg0_1 + s71 = arg2_1 + s0 = arg3_1 + s14 = arg6_1 + assert_size_stride(arg1_1, (2, s3, 32000), (32000*s3, 32000, 1)) + assert_size_stride(arg4_1, (2, s3, 32000), (s71, 32000, 1)) + assert_size_stride(arg5_1, (2, s3, 1), (s3, 1, 1)) + assert_size_stride(arg7_1, (2, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + triton_red_fused_argmax_0_xnumel = 2*s3 + stream5 = get_raw_stream(5) + triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream5) + del arg1_1 + buf1 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + triton_red_fused_argmax_1_xnumel = 2*s3 + stream5 = get_raw_stream(5) + triton_red_fused_argmax_1.run(arg4_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream5) + del arg4_1 + buf2 = empty_strided_cuda((), (), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 2*s3 + stream5 = get_raw_stream(5) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg5_1, buf2, 1, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream5) + del arg5_1 + del buf0 + del buf1 + buf4 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] + triton_red_fused_clamp_min_div_sum_3_r0_numel = 2*s14 + stream5 = get_raw_stream(5) + triton_red_fused_clamp_min_div_sum_3.run(arg7_1, buf2, buf4, 1, triton_red_fused_clamp_min_div_sum_3_r0_numel, stream=stream5) + del arg7_1 + del buf2 + return (buf4, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1569 + arg1_1 = rand_strided((2, 1569, 32000), (50208000, 32000, 1), device='cuda:5', dtype=torch.bfloat16) + arg2_1 = 50432000 + arg3_1 = 32000 + arg4_1 = rand_strided((2, 1569, 32000), (50432000, 32000, 1), device='cuda:5', dtype=torch.float32) + arg5_1 = rand_strided((2, 1569, 1), (1569, 1, 1), device='cuda:5', dtype=torch.int64) + arg6_1 = 1569 + arg7_1 = rand_strided((2, 1569, 1), (1569, 1, 1), device='cuda:5', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/f7/cf7xv6jfcmlfiqcycxf4jqhyeqqseswqjzqs3pbgegd4gpxeqrsj.py b/SpecForge-ext/cache/compiled_kernels/f7/cf7xv6jfcmlfiqcycxf4jqhyeqqseswqjzqs3pbgegd4gpxeqrsj.py new file mode 100644 index 0000000000000000000000000000000000000000..4f240560eac4d7e75f557440f5393fc9d9804092 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/f7/cf7xv6jfcmlfiqcycxf4jqhyeqqseswqjzqs3pbgegd4gpxeqrsj.py @@ -0,0 +1,711 @@ +# AOT ID: ['13_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[8, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[8, 1, s99][s99, s99, 1]cuda:0" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[8, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:0" = PlaceHolder[target=primals_9] +# %primals_17 : Tensor "i32[8, 1, s94][s94, s94, 1]cuda:0" = PlaceHolder[target=primals_17] +# %primals_20 : Tensor "i32[8, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:0" = PlaceHolder[target=primals_20] +# %primals_14 : Tensor "i64[8][1]cuda:0" = PlaceHolder[target=primals_14] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + s75 = primals_15 + s94 = primals_16 + s28 = primals_18 + s4 = primals_19 + s56 = primals_21 + s84 = primals_23 + s53 = primals_24 + s100 = primals_26 + s6 = primals_28 + s10 = primals_29 + assert_size_stride(primals_2, (8, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_6, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_9, (8, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (8, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (8, ), (1, )) + assert_size_stride(primals_17, (8, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_20, (8, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_22, (8, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_25, (8, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_27, (8, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_30, (8, 1, s6, s10), (s10*s6, s10*s6, s10, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((8, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((8, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((8, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_17, primals_20, primals_14, buf2, s37, s0, s99, s22, s72, s75, (127 + s37) // 128, 8, 32, stream=stream0) + del buf1 + return (buf2, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, buf2, buf0, s37, s0, s75, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2009 + primals_2 = rand_strided((8, 32, 2009, 128), (8228864, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_3 = 2009 + primals_4 = rand_strided((8, 8, 2009, 128), (2057216, 257152, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_5 = 2009 + primals_6 = rand_strided((8, 8, 2009, 128), (2057216, 257152, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_7 = 16 + primals_8 = 16 + primals_9 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + primals_10 = 2009 + primals_11 = 2009 + primals_12 = 16 + primals_13 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_14 = rand_strided((8, ), (1, ), device='cuda:0', dtype=torch.int64) + primals_15 = 2009 + primals_16 = 16 + primals_17 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_18 = 16 + primals_19 = 16 + primals_20 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + primals_21 = 16 + primals_22 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_23 = 16 + primals_24 = 16 + primals_25 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + primals_26 = 16 + primals_27 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_28 = 16 + primals_29 = 16 + primals_30 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/fe/1daa8ded4b82d70d3439801170abe262c13bdfc4c165f04dbb440bc1e877c465.best_config b/SpecForge-ext/cache/compiled_kernels/fe/1daa8ded4b82d70d3439801170abe262c13bdfc4c165f04dbb440bc1e877c465.best_config new file mode 100644 index 0000000000000000000000000000000000000000..e9b96a126fb37b684d7d003c0adf1a0efd4c8fc6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fe/1daa8ded4b82d70d3439801170abe262c13bdfc4c165f04dbb440bc1e877c465.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 20, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/fe/cfexazzkclzxl7gtdsot7ykn5ddtljv4fho2zkbbz3m4i5sjs3qx.py b/SpecForge-ext/cache/compiled_kernels/fe/cfexazzkclzxl7gtdsot7ykn5ddtljv4fho2zkbbz3m4i5sjs3qx.py new file mode 100644 index 0000000000000000000000000000000000000000..6542358bb3d49efadc887b8ebac4182a5c41aa4d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fe/cfexazzkclzxl7gtdsot7ykn5ddtljv4fho2zkbbz3m4i5sjs3qx.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/fm/cfmxabgahcooc23j5hyatoooxsnwvxl636joiaolujw2olqnzhuf.py b/SpecForge-ext/cache/compiled_kernels/fm/cfmxabgahcooc23j5hyatoooxsnwvxl636joiaolujw2olqnzhuf.py new file mode 100644 index 0000000000000000000000000000000000000000..c1497485cdc01a7e96a0b6f4cbcd87795220938b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fm/cfmxabgahcooc23j5hyatoooxsnwvxl636joiaolujw2olqnzhuf.py @@ -0,0 +1,161 @@ +# AOT ID: ['11_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/vk/cvkcbju4ftdjugozv3aumhlgwacbn2h4ae4bwnnofexgmrt5upru.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg1_1 : Tensor "bf16[8, s67, 32000][32000*s67, 32000, 1]cuda:0" = PlaceHolder[target=arg1_1] +# %getitem : Tensor "f32[8, s67, 1][s67, 1, 8*s67]cuda:0" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[8, s67, 1][s67, 1, 8*s67]cuda:0" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[8, s67, 32000][32000*s67, 32000, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg1_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[8, s67, 32000][32000*s67, 32000, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[8, s67, 32000][32000*s67, 32000, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[8, s67, 32000][32000*s67, 32000, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1 = args + args.clear() + s67 = arg0_1 + assert_size_stride(arg1_1, (8, s67, 32000), (32000*s67, 32000, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf2 = empty_strided_cuda((8, s67, 32000), (32000*s67, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel = 8*s67 + stream0 = get_raw_stream(0) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg1_1, buf2, triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel, 32000, stream=stream0) + del arg1_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2009 + arg1_1 = rand_strided((8, 2009, 32000), (64288000, 32000, 1), device='cuda:0', dtype=torch.bfloat16) + fn = lambda: call([arg0_1, arg1_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/fr/cfrczc5cr5uotil5g5x435datuzfao56zz4vsxlh33jteluxhhme.py b/SpecForge-ext/cache/compiled_kernels/fr/cfrczc5cr5uotil5g5x435datuzfao56zz4vsxlh33jteluxhhme.py new file mode 100644 index 0000000000000000000000000000000000000000..e4a87af1bf5b5709473dd0bbfd59f43e4011b5db --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fr/cfrczc5cr5uotil5g5x435datuzfao56zz4vsxlh33jteluxhhme.py @@ -0,0 +1,25 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 2176 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/fr/df86db1aa782529bbee54b1dae6ae8836a0d24759a54f5db65b8185538580f0c.best_config b/SpecForge-ext/cache/compiled_kernels/fr/df86db1aa782529bbee54b1dae6ae8836a0d24759a54f5db65b8185538580f0c.best_config new file mode 100644 index 0000000000000000000000000000000000000000..7d56ea7451f6ff3ceffec392bc015b86ab20533e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fr/df86db1aa782529bbee54b1dae6ae8836a0d24759a54f5db65b8185538580f0c.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "4UWYNBR3KPWQGNAZ5LIIRE7YAZWTQP4CP3JS6GOSLWYDF5K7WTAA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ge/cge7pldzjp233r253br5rjlpsooeme3wcmy44ghmepsv2frglxda.py b/SpecForge-ext/cache/compiled_kernels/ge/cge7pldzjp233r253br5rjlpsooeme3wcmy44ghmepsv2frglxda.py new file mode 100644 index 0000000000000000000000000000000000000000..11a883773aa201b3e8376eb27e00dd3260a193f1 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ge/cge7pldzjp233r253br5rjlpsooeme3wcmy44ghmepsv2frglxda.py @@ -0,0 +1,527 @@ +# AOT ID: ['8_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/u5/cu5vbm4k5rx2ckzgfjj47hdlzuvwn5xanjqx3duors7zpk22vecs.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 8192}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sh/cshnc2xfeqljlx2yygmlwikvuwrqaw3mjbtw2iod3tdoksj22xly.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_1, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_12 +# full_blocks => eq_24 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_12, view_7 +# suffix_mask => ge_2 +# Graph fragment: +# %arg1_1 : Tensor "i64[8][1]cuda:4" = PlaceHolder[target=arg1_1] +# %sum_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:4" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:4"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:4"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %ge_1 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[8][1]cuda:4"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %index : Tensor "i64[8][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[8, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {}) +# %lt : Tensor "b8[8, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, %arg0_1]), kwargs = {}) +# %index_1 : Tensor "i64[8][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[8, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_2 : Tensor "b8[s37][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[8][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[8, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, %arg0_1]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub_12 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_12, 2048), kwargs = {}) +# %eq_12 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_12), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[8, 1, 2048, s37][2048*Max(1, s37), 2048*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, %arg0_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[8, 1, 2048, 128*(((s37 + 127)//128))][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_23, 0, 0], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[8, 1, 16, 128, ((s37 + 127)//128), 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [8, 1, 16, 128, %floordiv_1, 128]), kwargs = {}) +# %permute : Tensor "b8[8, 1, 16, ((s37 + 127)//128), 128, 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_24 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_24, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xy/cxyrsscd2x5irpnkxj53kruljpnjhmfn2h2idqa3d2fme5nha4jg.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# num_blocks_in_row => sum_2 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:4" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:4" = PlaceHolder[target=sum_2] +# %getitem_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 128*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:4" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %iota_7 : Tensor "i64[8][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:4, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_1,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:4, requires_grad: False}) +# %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_1,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4}) +# %where : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %index_put : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %sum_2,%convert_element_type_3,%convert_element_type_4,%buf13 +triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2 = async_compile.triton('triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bn/cbnlhjvdilhpyvbq45tudsctw5loi5h3hcwqzatfbvx3vqngs73c.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf13 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:4" = PlaceHolder[target=buf13] +# %buf15 : Tensor "i16[8, 1, ((s37 + 127)//128), 16][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), 16, 1]cuda:4" = PlaceHolder[target=buf15] +# %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][((s37 + 127)//128), 8*(((s37 + 127)//128)), 1]cuda:4" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_1), kwargs = {}) +# %clone_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf15,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 256, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1 = args + args.clear() + s37 = arg0_1 + assert_size_stride(arg1_1, (8, ), (1, )) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf12 = empty_strided_cuda((8, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 128 + 128*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream4) + buf19 = empty_strided_cuda((8, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 128 + 128*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused_new_zeros_0.run(buf19, triton_poi_fused_new_zeros_0_xnumel, stream=stream4) + ps0 = (127 + s37) // 128 + ps1 = 16*((127 + s37) // 128) + buf1 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 128*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 128*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 128*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg1_1, buf1, buf5, ps0, s37, ps1, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream4) + del arg1_1 + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + buf4 = buf2[1] + assert_size_stride(buf4, (8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 128*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf10 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf11 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf1, buf4, buf10, buf11, buf12, ps0, s37, 128, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream4) + del buf1 + del buf4 + buf26 = empty_strided_cuda((8, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32) + buf28 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 8*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf12, buf26, buf28, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream4) + del buf12 + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + buf8 = buf6[1] + assert_size_stride(buf8, (8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 128*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf17 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf18 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf5, buf8, buf17, buf18, buf19, ps0, s37, 128, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream4) + del buf5 + del buf8 + buf23 = empty_strided_cuda((8, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32) + buf25 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 8*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf19, buf23, buf25, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream4) + del buf19 + return (buf23, buf25, buf26, buf28, buf18, buf17, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 4096 + arg1_1 = rand_strided((8, ), (1, ), device='cuda:4', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ge/cgeiheqll5tdjxehz3sguv4colaa2m4pazsls7jb5uiu7fwvsxkb.py b/SpecForge-ext/cache/compiled_kernels/ge/cgeiheqll5tdjxehz3sguv4colaa2m4pazsls7jb5uiu7fwvsxkb.py new file mode 100644 index 0000000000000000000000000000000000000000..2db4b32720978a726425a25d23b70dceaedb0959 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ge/cgeiheqll5tdjxehz3sguv4colaa2m4pazsls7jb5uiu7fwvsxkb.py @@ -0,0 +1,352 @@ +# AOT ID: ['14_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/6b/c6beknosybos5d54llineldguuueh3kpjlkiuzm4pkorx7g6mjh6.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s3, 32000][32000*s3, 32000, 1]cuda:0" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/rm/crmfgeenggpe7hoot35x4eji7uf7h6kj6uq5zcsn2zuahh3agba4.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg3_1 : Tensor "f32[2, s3, 32000][s71, 32000, 1]cuda:0" = PlaceHolder[target=arg3_1] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg3_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/jx/cjxwalfxkks7buo2fu2ztx36ob736y24q4qdjpbqms4lwgsoo637.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq_2 +# mul => mul_3 +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:0" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:0" = PlaceHolder[target=argmax_1] +# %arg4_1 : Tensor "i64[2, s3, 1][s3, 1, 1]cuda:0" = PlaceHolder[target=arg4_1] +# %eq_2 : Tensor "b8[2, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[2, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {}) +# %mul_3 : Tensor "i64[2, s3][s3, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_3,), kwargs = {}) +# return %sum_1 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/p4/cp4ijj7ba3ojx3goskrebxodtrzso4wryuwcis3dhm4ynbtr4x76.py +# Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %arg6_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:0" = PlaceHolder[target=arg6_1] +# %sum_1 : Tensor "i64[][]cuda:0" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:0" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_2,%div +triton_red_fused_clamp_min_div_sum_3 = async_compile.triton('triton_red_fused_clamp_min_div_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tmp4 = tl.load(in_ptr1 + (0)) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1]) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp2.to(tl.float32) + tmp8 = 1e-06 + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = (tmp6 / tmp9) + tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1 = args + args.clear() + s3 = arg0_1 + s71 = arg2_1 + s14 = arg5_1 + assert_size_stride(arg1_1, (2, s3, 32000), (32000*s3, 32000, 1)) + assert_size_stride(arg3_1, (2, s3, 32000), (s71, 32000, 1)) + assert_size_stride(arg4_1, (2, s3, 1), (s3, 1, 1)) + assert_size_stride(arg6_1, (2, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + triton_red_fused_argmax_0_xnumel = 2*s3 + stream0 = get_raw_stream(0) + triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream0) + del arg1_1 + buf1 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + triton_red_fused_argmax_1_xnumel = 2*s3 + stream0 = get_raw_stream(0) + triton_red_fused_argmax_1.run(arg3_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream0) + del arg3_1 + buf2 = empty_strided_cuda((), (), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 2*s3 + stream0 = get_raw_stream(0) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg4_1, buf2, 1, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream0) + del arg4_1 + del buf0 + del buf1 + buf4 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] + triton_red_fused_clamp_min_div_sum_3_r0_numel = 2*s14 + stream0 = get_raw_stream(0) + triton_red_fused_clamp_min_div_sum_3.run(arg6_1, buf2, buf4, 1, triton_red_fused_clamp_min_div_sum_3_r0_numel, stream=stream0) + del arg6_1 + del buf2 + return (buf4, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1130 + arg1_1 = rand_strided((2, 1130, 32000), (36160000, 32000, 1), device='cuda:0', dtype=torch.bfloat16) + arg2_1 = 36384000 + arg3_1 = rand_strided((2, 1130, 32000), (36384000, 32000, 1), device='cuda:0', dtype=torch.float32) + arg4_1 = rand_strided((2, 1130, 1), (1130, 1, 1), device='cuda:0', dtype=torch.int64) + arg5_1 = 1130 + arg6_1 = rand_strided((2, 1130, 1), (1130, 1, 1), device='cuda:0', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ge/cgeiqrn6fc2445boi46pfasru3dymyjiw2xhga6ztucscbgv3gtp.py b/SpecForge-ext/cache/compiled_kernels/ge/cgeiqrn6fc2445boi46pfasru3dymyjiw2xhga6ztucscbgv3gtp.py new file mode 100644 index 0000000000000000000000000000000000000000..70a44ad98a8cec6c77fb367ca91e930f20cd0290 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ge/cgeiqrn6fc2445boi46pfasru3dymyjiw2xhga6ztucscbgv3gtp.py @@ -0,0 +1,63 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 1310720000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) diff --git a/SpecForge-ext/cache/compiled_kernels/gj/81de88c34f763d66a846c9fad3cb18a15e21ec8558a3c392507c8a42067a9f48.best_config b/SpecForge-ext/cache/compiled_kernels/gj/81de88c34f763d66a846c9fad3cb18a15e21ec8558a3c392507c8a42067a9f48.best_config new file mode 100644 index 0000000000000000000000000000000000000000..73d39cec03a4913ffd38deb7ad038bf56b5cd33f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/gj/81de88c34f763d66a846c9fad3cb18a15e21ec8558a3c392507c8a42067a9f48.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/gj/cgj33ylywpycbjnkt2jhtyk3565fuxbyhxodkhpafgpel3vyoq3v.py b/SpecForge-ext/cache/compiled_kernels/gj/cgj33ylywpycbjnkt2jhtyk3565fuxbyhxodkhpafgpel3vyoq3v.py new file mode 100644 index 0000000000000000000000000000000000000000..e1fadbdce24300db0d66627a9f9a8ddcee14167c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/gj/cgj33ylywpycbjnkt2jhtyk3565fuxbyhxodkhpafgpel3vyoq3v.py @@ -0,0 +1,303 @@ +# AOT ID: ['7_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/vz/cvzofvv5xx3zbd3qsg6ytmxqt6aoybtka6jzw6fa2ybplg6reklt.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[2, 2048, 32000][65536000, 32000, 1]cuda:3" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 262144000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/6z/c6zba2r22yyctp3hlaofoasgbhbtwqg7txp443m7rffkvpbhn34q.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg1_1 : Tensor "f32[2, 2048, 32000][65760000, 32000, 1]cuda:3" = PlaceHolder[target=arg1_1] +# %argmax_1 : Tensor "i64[2, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 524288000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/vd/cvdnn2wxnzroeblfnctzs3gfd2zra7kmmbk3o2b3d473j3vqfrr5.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:3" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[2, 2048][2048, 1]cuda:3" = PlaceHolder[target=argmax_1] +# %arg2_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:3" = PlaceHolder[target=arg2_1] +# %arg3_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:3" = PlaceHolder[target=arg3_1] +# %sum_1 : Tensor "i64[][]cuda:3" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:3" = PlaceHolder[target=sum_2] +# %eq : Tensor "b8[2, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[2, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[2, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {}) +# %sum_2 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_1,%sum_2,%div +triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}} +) +@triton.jit +def triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 4096 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = _tmp11 + tmp10 + _tmp11 = tl.where(r0_mask, tmp12, _tmp11) + tmp11 = tl.sum(_tmp11, 1)[:, None] + tmp13 = tmp7.to(tl.float32) + tmp14 = tmp11.to(tl.float32) + tmp15 = 1e-06 + tmp16 = triton_helpers.maximum(tmp14, tmp15) + tmp17 = (tmp13 / tmp16) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + assert_size_stride(arg0_1, (2, 2048, 32000), (65536000, 32000, 1)) + assert_size_stride(arg1_1, (2, 2048, 32000), (65760000, 32000, 1)) + assert_size_stride(arg2_1, (2, 2048, 1), (2048, 1, 1)) + assert_size_stride(arg3_1, (2, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + stream3 = get_raw_stream(3) + triton_red_fused_argmax_0.run(arg0_1, buf0, 4096, 32000, stream=stream3) + del arg0_1 + buf1 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + stream3 = get_raw_stream(3) + triton_red_fused_argmax_1.run(arg1_1, buf1, 4096, 32000, stream=stream3) + del arg1_1 + buf4 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] + stream3 = get_raw_stream(3) + triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2.run(buf0, buf1, arg2_1, arg3_1, buf4, 1, 4096, stream=stream3) + del arg2_1 + del arg3_1 + del buf0 + del buf1 + return (buf4, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, 2048, 32000), (65536000, 32000, 1), device='cuda:3', dtype=torch.bfloat16) + arg1_1 = rand_strided((2, 2048, 32000), (65760000, 32000, 1), device='cuda:3', dtype=torch.float32) + arg2_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:3', dtype=torch.int64) + arg3_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:3', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/gj/cgjwgvjl5tscnvomgnhypepg6kwritqye24dnb32wl5uai4wdynt.py b/SpecForge-ext/cache/compiled_kernels/gj/cgjwgvjl5tscnvomgnhypepg6kwritqye24dnb32wl5uai4wdynt.py new file mode 100644 index 0000000000000000000000000000000000000000..ef3e727d6ee03d7be8dc13f7f543099bd97719c8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/gj/cgjwgvjl5tscnvomgnhypepg6kwritqye24dnb32wl5uai4wdynt.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 256, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/gp/cgpqg54v7ag6awmgwhlrbbyw5jxsgjo6tuzvo3rt2xzqk6f33df2.py b/SpecForge-ext/cache/compiled_kernels/gp/cgpqg54v7ag6awmgwhlrbbyw5jxsgjo6tuzvo3rt2xzqk6f33df2.py new file mode 100644 index 0000000000000000000000000000000000000000..851a5c1c90285419de1438392f0c31dd0bcc2791 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/gp/cgpqg54v7ag6awmgwhlrbbyw5jxsgjo6tuzvo3rt2xzqk6f33df2.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/gr/cgrxjgtvhqzev7iiyqnuztmmz3pcmbdwcrbj25hftfrgrwd6xszx.py b/SpecForge-ext/cache/compiled_kernels/gr/cgrxjgtvhqzev7iiyqnuztmmz3pcmbdwcrbj25hftfrgrwd6xszx.py new file mode 100644 index 0000000000000000000000000000000000000000..051ef2504e5756ccbee4d77db265c619fa1d5cce --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/gr/cgrxjgtvhqzev7iiyqnuztmmz3pcmbdwcrbj25hftfrgrwd6xszx.py @@ -0,0 +1,527 @@ +# AOT ID: ['8_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ti/ctihvtg4pbsacqwjusips66jf62gkkpydmo3prc435glbmxzyjmy.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gn/cgntwrk35eylt5m7jg6n3ah6b6hntq56lwmpg3vpgatqseurdonb.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_1, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_12 +# full_blocks => eq_24 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_12, view_7 +# suffix_mask => ge_2 +# Graph fragment: +# %arg1_1 : Tensor "i64[2][1]cuda:4" = PlaceHolder[target=arg1_1] +# %sum_1 : Tensor "i64[2, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 32*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:4" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[2, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:4"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:4"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %ge_1 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[2][1]cuda:4"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %index : Tensor "i64[2][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[2, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {}) +# %lt : Tensor "b8[2, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, %arg0_1]), kwargs = {}) +# %index_1 : Tensor "i64[2][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[2, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[2, 2048][2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[2, 2048, 1][2048, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_2 : Tensor "b8[s37][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[2][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[2, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, %arg0_1]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub_12 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_12, 2048), kwargs = {}) +# %eq_12 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_12), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[2, 1, 2048, s37][2048*Max(1, s37), 2048*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, 2048, %arg0_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[2, 1, 2048, 128*(((s37 + 127)//128))][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_23, 0, 0], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[2, 1, 16, 128, ((s37 + 127)//128), 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [2, 1, 16, 128, %floordiv_1, 128]), kwargs = {}) +# %permute : Tensor "b8[2, 1, 16, ((s37 + 127)//128), 128, 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_24 : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_24, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1024, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/d4/cd4dkr5xpb65gluxif27ifpd5dlgspgfoncrldfk4wsiw757dc6j.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# num_blocks_in_row => sum_2 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 32*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:4" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:4" = PlaceHolder[target=sum_2] +# %getitem_1 : Tensor "i64[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 32*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:4" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %iota_7 : Tensor "i64[2][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[2, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:4, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_1,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:4, requires_grad: False}) +# %sum_2 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[2, 1, 16, 1][16, 16, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_1,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4}) +# %where : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %index_put : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %sum_2,%convert_element_type_3,%convert_element_type_4,%buf13 +triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2 = async_compile.triton('triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mx/cmxtehlygzy2cjddulwsvjghigetqtozdl5ft6qfk3edunt3obku.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf13 : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:4" = PlaceHolder[target=buf13] +# %buf15 : Tensor "i16[2, 1, ((s37 + 127)//128), 16][16*(((s37 + 127)//128)), 32*(((s37 + 127)//128)), 16, 1]cuda:4" = PlaceHolder[target=buf15] +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][((s37 + 127)//128), 2*(((s37 + 127)//128)), 1]cuda:4" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_1), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[2, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[2, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[2, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf15,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 64, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1 = args + args.clear() + s37 = arg0_1 + assert_size_stride(arg1_1, (2, ), (1, )) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf12 = empty_strided_cuda((2, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 32 + 32*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream4) + buf19 = empty_strided_cuda((2, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 32 + 32*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused_new_zeros_0.run(buf19, triton_poi_fused_new_zeros_0_xnumel, stream=stream4) + ps0 = (127 + s37) // 128 + ps1 = 16*((127 + s37) // 128) + buf1 = empty_strided_cuda((2, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 32*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((2, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 32*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 32*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg1_1, buf1, buf5, ps0, s37, ps1, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream4) + del arg1_1 + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + buf4 = buf2[1] + assert_size_stride(buf4, (2, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 32*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf10 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + buf11 = empty_strided_cuda((2, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf1, buf4, buf10, buf11, buf12, ps0, s37, 32, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream4) + del buf1 + del buf4 + buf26 = empty_strided_cuda((2, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32) + buf28 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 2*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf12, buf26, buf28, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream4) + del buf12 + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + buf8 = buf6[1] + assert_size_stride(buf8, (2, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 32*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf17 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + buf18 = empty_strided_cuda((2, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf5, buf8, buf17, buf18, buf19, ps0, s37, 32, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream4) + del buf5 + del buf8 + buf23 = empty_strided_cuda((2, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32) + buf25 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 2*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf19, buf23, buf25, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream4) + del buf19 + return (buf23, buf25, buf26, buf28, buf18, buf17, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 4096 + arg1_1 = rand_strided((2, ), (1, ), device='cuda:4', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/i7/753f096bb7aa5e1fecb278d28fbe292a2ea22998bc17497e9b0f2de87744233b.best_config b/SpecForge-ext/cache/compiled_kernels/i7/753f096bb7aa5e1fecb278d28fbe292a2ea22998bc17497e9b0f2de87744233b.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b9c83cd70cc4f7d46eca037549afe001d843ad6c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/i7/753f096bb7aa5e1fecb278d28fbe292a2ea22998bc17497e9b0f2de87744233b.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 49, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/i7/ci72wbomeqiqrinpf2bqkd3bkzlans6x5wsg36itkz6xlzcsenoc.py b/SpecForge-ext/cache/compiled_kernels/i7/ci72wbomeqiqrinpf2bqkd3bkzlans6x5wsg36itkz6xlzcsenoc.py new file mode 100644 index 0000000000000000000000000000000000000000..6bb5bbf62d51727927d39fab4c9a9e37f8e6cce2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/i7/ci72wbomeqiqrinpf2bqkd3bkzlans6x5wsg36itkz6xlzcsenoc.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/i7/ci7bv4n3qd2yby5fwnpfmwqhmlmbx5hqe2x24h5oherh4kv76un4.py b/SpecForge-ext/cache/compiled_kernels/i7/ci7bv4n3qd2yby5fwnpfmwqhmlmbx5hqe2x24h5oherh4kv76un4.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb174545c30391373d71ba03532046f6d734195 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/i7/ci7bv4n3qd2yby5fwnpfmwqhmlmbx5hqe2x24h5oherh4kv76un4.py @@ -0,0 +1,416 @@ +# AOT ID: ['14_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mj/cmjlfojpnpm5jni2ravb3komgycjc5mn3sbp2hi3ttso25z44mlc.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[8, s3, 32000][32000*s3, 32000, 1]cuda:4" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[8, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/am/camtlvdra4wkjpnusgr2wvtfxaqcnp25a5th4hbccdawibyl2rt3.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg3_1 : Tensor "f32[8, s3, 32000][s71, 32000, 1]cuda:4" = PlaceHolder[target=arg3_1] +# %argmax_1 : Tensor "i64[8, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg3_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mr/cmrf526mm6su5xqnmne5okkpp4fxut73afx62yxlvlmbr6yjqxen.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq_2 +# mul => mul_7 +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[8, s3][s3, 1]cuda:4" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[8, s3][s3, 1]cuda:4" = PlaceHolder[target=argmax_1] +# %arg4_1 : Tensor "i64[8, s3, 1][s3, 1, 1]cuda:4" = PlaceHolder[target=arg4_1] +# %eq_2 : Tensor "b8[8, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {}) +# %mul_7 : Tensor "i64[8, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_7,), kwargs = {}) +# return %buf3 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tl.load(in_ptr1 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp4 = tl.load(in_ptr2 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp7, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/pf/cpfgylkqu7arpbcogudg5r4v6hqz6ojt7pirbokbuqke3dpgsed7.py +# Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] +# Source node to ATen node mapping: +# sum_2 => sum_2 +# Graph fragment: +# %arg6_1 : Tensor "i64[8, s14, 1][s14, 1, 1]cuda:4" = PlaceHolder[target=arg6_1] +# %sum_2 : Tensor "i64[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {}) +# return %buf5 +triton_red_fused_sum_3 = async_compile.triton('triton_red_fused_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4l/c4lbz3jtnjjxbp7lftpjy4iam6ao6fc5cpp42bxihe27bm4qlhss.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# eq => eq_2 +# mul => mul_7 +# squeeze => squeeze +# sum_1 => sum_1 +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %buf3 : Tensor "i64[2][1]cuda:4" = PlaceHolder[target=buf3] +# %buf5 : Tensor "i64[2][1]cuda:4" = PlaceHolder[target=buf5] +# %sum_1 : Tensor "i64[][]cuda:4" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:4" = PlaceHolder[target=sum_2] +# %eq_2 : Tensor "b8[8, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {}) +# %mul_7 : Tensor "i64[8, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_7,), kwargs = {}) +# %sum_2 : Tensor "i64[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_1,%sum_2,%div +triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4 = async_compile.triton('triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 1, 'r0_': 2}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}} +) +@triton.jit +def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 2 + R0_BLOCK: tl.constexpr = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), None) + tmp4 = tl.load(in_ptr1 + (r0_0), None) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK]) + tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64) + tmp8 = tmp3.to(tl.float32) + tmp9 = tmp7.to(tl.float32) + tmp10 = 1e-06 + tmp11 = triton_helpers.maximum(tmp9, tmp10) + tmp12 = (tmp8 / tmp11) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1 = args + args.clear() + s3 = arg0_1 + s71 = arg2_1 + s14 = arg5_1 + assert_size_stride(arg1_1, (8, s3, 32000), (32000*s3, 32000, 1)) + assert_size_stride(arg3_1, (8, s3, 32000), (s71, 32000, 1)) + assert_size_stride(arg4_1, (8, s3, 1), (s3, 1, 1)) + assert_size_stride(arg6_1, (8, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((8, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + triton_red_fused_argmax_0_xnumel = 8*s3 + stream4 = get_raw_stream(4) + triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream4) + del arg1_1 + buf1 = empty_strided_cuda((8, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + triton_red_fused_argmax_1_xnumel = 8*s3 + stream4 = get_raw_stream(4) + triton_red_fused_argmax_1.run(arg3_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream4) + del arg3_1 + buf3 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 4*s3 + stream4 = get_raw_stream(4) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg4_1, buf3, s3, 2, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream4) + del arg4_1 + del buf0 + del buf1 + buf5 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] + triton_red_fused_sum_3_r0_numel = 4*s14 + stream4 = get_raw_stream(4) + triton_red_fused_sum_3.run(arg6_1, buf5, s14, 2, triton_red_fused_sum_3_r0_numel, stream=stream4) + del arg6_1 + buf7 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] + stream4 = get_raw_stream(4) + triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4.run(buf3, buf5, buf7, 1, 2, stream=stream4) + del buf3 + del buf5 + return (buf7, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1896 + arg1_1 = rand_strided((8, 1896, 32000), (60672000, 32000, 1), device='cuda:4', dtype=torch.bfloat16) + arg2_1 = 60896000 + arg3_1 = rand_strided((8, 1896, 32000), (60896000, 32000, 1), device='cuda:4', dtype=torch.float32) + arg4_1 = rand_strided((8, 1896, 1), (1896, 1, 1), device='cuda:4', dtype=torch.int64) + arg5_1 = 1896 + arg6_1 = rand_strided((8, 1896, 1), (1896, 1, 1), device='cuda:4', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/i7/ci7k4nsws2k2ul2efnmk5kpdyf23awz656g7ls3el2bksnqpwzrz.py b/SpecForge-ext/cache/compiled_kernels/i7/ci7k4nsws2k2ul2efnmk5kpdyf23awz656g7ls3el2bksnqpwzrz.py new file mode 100644 index 0000000000000000000000000000000000000000..f743529badde0a8db03fa6c85e75d619b43bf0a3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/i7/ci7k4nsws2k2ul2efnmk5kpdyf23awz656g7ls3el2bksnqpwzrz.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 262144000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/il/cilp5kqvrljsbeu2eadyyvf76cdxypu34m6m4bfrk3qitwvhuaei.py b/SpecForge-ext/cache/compiled_kernels/il/cilp5kqvrljsbeu2eadyyvf76cdxypu34m6m4bfrk3qitwvhuaei.py new file mode 100644 index 0000000000000000000000000000000000000000..6332682573237289cac22eefe45b8eb6ac9f8c37 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/il/cilp5kqvrljsbeu2eadyyvf76cdxypu34m6m4bfrk3qitwvhuaei.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/io/ciomjhm6gvshkufbv7mvkwm7tur3yulj2gzni6c6dk5wy5ngcabp.py b/SpecForge-ext/cache/compiled_kernels/io/ciomjhm6gvshkufbv7mvkwm7tur3yulj2gzni6c6dk5wy5ngcabp.py new file mode 100644 index 0000000000000000000000000000000000000000..49679c077131ea860738876dac02249f1e9c1f29 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/io/ciomjhm6gvshkufbv7mvkwm7tur3yulj2gzni6c6dk5wy5ngcabp.py @@ -0,0 +1,527 @@ +# AOT ID: ['8_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uq/cuq3kpq5p236ryparerpt762hx5ibuzufssmu2dgqtdmebfebtve.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 8192}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/da/cdayidk2nqaqq2zcz2eccpmjrvlgr344k5o6lxh7dhtpu3ktudj5.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_1, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_12 +# full_blocks => eq_24 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_12, view_7 +# suffix_mask => ge_2 +# Graph fragment: +# %arg1_1 : Tensor "i64[8][1]cuda:5" = PlaceHolder[target=arg1_1] +# %sum_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:5" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:5"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:5"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %ge_1 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[8][1]cuda:5"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %index : Tensor "i64[8][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[8, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {}) +# %lt : Tensor "b8[8, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, %arg0_1]), kwargs = {}) +# %index_1 : Tensor "i64[8][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[8, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_2 : Tensor "b8[s37][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[8][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[8, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, %arg0_1]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub_12 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_12, 2048), kwargs = {}) +# %eq_12 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_12), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[8, 1, 2048, s37][2048*Max(1, s37), 2048*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, %arg0_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[8, 1, 2048, 128*(((s37 + 127)//128))][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_23, 0, 0], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[8, 1, 16, 128, ((s37 + 127)//128), 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [8, 1, 16, 128, %floordiv_1, 128]), kwargs = {}) +# %permute : Tensor "b8[8, 1, 16, ((s37 + 127)//128), 128, 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_24 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_24, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/d5/cd5lwybnbjetoaf6hxajj7itqmrk3fj4xejz52d5s2w56qouijor.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# num_blocks_in_row => sum_2 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:5" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:5" = PlaceHolder[target=sum_2] +# %getitem_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 128*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:5" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %iota_7 : Tensor "i64[8][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:5, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_1,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:5, requires_grad: False}) +# %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_1,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5}) +# %where : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %index_put : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %sum_2,%convert_element_type_3,%convert_element_type_4,%buf13 +triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2 = async_compile.triton('triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/7t/c7t3uvardqlt6x3sz37tlydghb4rt6mdilzlc7ffz3pehdn5jwdj.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf13 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:5" = PlaceHolder[target=buf13] +# %buf15 : Tensor "i16[8, 1, ((s37 + 127)//128), 16][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), 16, 1]cuda:5" = PlaceHolder[target=buf15] +# %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][((s37 + 127)//128), 8*(((s37 + 127)//128)), 1]cuda:5" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_1), kwargs = {}) +# %clone_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf15,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 256, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1 = args + args.clear() + s37 = arg0_1 + assert_size_stride(arg1_1, (8, ), (1, )) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf12 = empty_strided_cuda((8, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 128 + 128*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream5) + buf19 = empty_strided_cuda((8, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 128 + 128*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_poi_fused_new_zeros_0.run(buf19, triton_poi_fused_new_zeros_0_xnumel, stream=stream5) + ps0 = (127 + s37) // 128 + ps1 = 16*((127 + s37) // 128) + buf1 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 128*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 128*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 128*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg1_1, buf1, buf5, ps0, s37, ps1, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream5) + del arg1_1 + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + buf4 = buf2[1] + assert_size_stride(buf4, (8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 128*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf10 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf11 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream5 = get_raw_stream(5) + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf1, buf4, buf10, buf11, buf12, ps0, s37, 128, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream5) + del buf1 + del buf4 + buf26 = empty_strided_cuda((8, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32) + buf28 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 8*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf12, buf26, buf28, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream5) + del buf12 + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + buf8 = buf6[1] + assert_size_stride(buf8, (8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 128*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf17 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf18 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream5 = get_raw_stream(5) + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf5, buf8, buf17, buf18, buf19, ps0, s37, 128, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream5) + del buf5 + del buf8 + buf23 = empty_strided_cuda((8, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32) + buf25 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 8*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf19, buf23, buf25, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream5) + del buf19 + return (buf23, buf25, buf26, buf28, buf18, buf17, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 4096 + arg1_1 = rand_strided((8, ), (1, ), device='cuda:5', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ir/cirnibisghsrxkpyxmbfn3vm524hpv5prkr6bklqnyd236xctufm.py b/SpecForge-ext/cache/compiled_kernels/ir/cirnibisghsrxkpyxmbfn3vm524hpv5prkr6bklqnyd236xctufm.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf4236f721e0a1dc619a9c4a61e96f5200c0f1a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ir/cirnibisghsrxkpyxmbfn3vm524hpv5prkr6bklqnyd236xctufm.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/iw/80cab71d995e2d820124a84d206a1829a167fc71775679064f219ffa7fb5af5d.best_config b/SpecForge-ext/cache/compiled_kernels/iw/80cab71d995e2d820124a84d206a1829a167fc71775679064f219ffa7fb5af5d.best_config new file mode 100644 index 0000000000000000000000000000000000000000..422e1afda877306872879bb2d038c5a4e486fa13 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/iw/80cab71d995e2d820124a84d206a1829a167fc71775679064f219ffa7fb5af5d.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 27, "triton_cache_hash": "NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/iw/ciwoxk7cuonocxkjitlvfvf5jppmr2duv6vgwzkwaw4xszgcaf5m.py b/SpecForge-ext/cache/compiled_kernels/iw/ciwoxk7cuonocxkjitlvfvf5jppmr2duv6vgwzkwaw4xszgcaf5m.py new file mode 100644 index 0000000000000000000000000000000000000000..d168e332d9a93e5a9bf862668bde222f0af81d78 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/iw/ciwoxk7cuonocxkjitlvfvf5jppmr2duv6vgwzkwaw4xszgcaf5m.py @@ -0,0 +1,86 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/iz/cizft6bshl3hhcm2jld522l7ngdnfv72p4dso6fqgtuu4wcga7a7.py b/SpecForge-ext/cache/compiled_kernels/iz/cizft6bshl3hhcm2jld522l7ngdnfv72p4dso6fqgtuu4wcga7a7.py new file mode 100644 index 0000000000000000000000000000000000000000..943e67df50d055126874796f73fda4f1b2c4cd0b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/iz/cizft6bshl3hhcm2jld522l7ngdnfv72p4dso6fqgtuu4wcga7a7.py @@ -0,0 +1,1065 @@ +# AOT ID: ['9_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4jnlpbb32eopbc2caystnepiaizyinwoncir73za7sf3sijadk.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/w6/cw664dpfyeegmrywhhjvuy7m6agln3vwpw6a655vqz2zfjueqy6w.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:1" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:1" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:1" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:1" = PlaceHolder[target=getitem_5] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[2, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:1" = PlaceHolder[target=primals_7] +# %primals_15 : Tensor "i32[2, 1, s56][s56, s56, 1]cuda:1" = PlaceHolder[target=primals_15] +# %primals_17 : Tensor "i32[2, 1, s84, 16][16*s84, 16*s84, 16, 1]cuda:1" = PlaceHolder[target=primals_17] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[2, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:1" = PlaceHolder[target=primals_13] +# %primals_19 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:1" = PlaceHolder[target=primals_19] +# %primals_21 : Tensor "i32[2, 1, s6, 16][16*s6, 16*s6, 16, 1]cuda:1" = PlaceHolder[target=primals_21] +# %primals_10 : Tensor "i64[2][1]cuda:1" = PlaceHolder[target=primals_10] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1 = args + args.clear() + s0 = primals_8 + s72 = primals_6 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_5, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (2, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, ), (1, )) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (2, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (2, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (2, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + assert_size_stride(getitem, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (2, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream1 = get_raw_stream(1) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 131072, 128, stream=stream1) + del getitem + buf3 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream1 = get_raw_stream(1) + triton_tem_fused_zeros_1.run(primals_1, primals_3, primals_5, getitem_1, buf1, tangents_1, buf3, buf4, primals_9, primals_7, primals_15, primals_17, primals_11, primals_13, primals_19, primals_21, primals_10, buf5, s0, s72, s56, s84, 64 + ((127 + s0) // 128), 2, 8, stream=stream1) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_13 + del primals_15 + del primals_17 + del primals_19 + del primals_21 + del primals_3 + del primals_5 + del primals_7 + del primals_9 + del tangents_1 + return (buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 4096 + primals_6 = 32 + primals_12 = 32 + primals_14 = 32 + primals_16 = 32 + primals_18 = 32 + primals_20 = 32 + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_5 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_7 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:1', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_10 = rand_strided((2, ), (1, ), device='cuda:1', dtype=torch.int64) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_13 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:1', dtype=torch.int32) + primals_15 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:1', dtype=torch.int32) + primals_17 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:1', dtype=torch.int32) + primals_19 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:1', dtype=torch.int32) + primals_21 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:1', dtype=torch.int32) + getitem = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 2048), (65536, 2048, 1), device='cuda:1', dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:1', dtype=torch.bfloat16) + fn = lambda: call([primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/jd/cjdjlzpgoi4d6trbdutm2zq3o2chrt2dek76dhf4dqdw5ppk4vke.py b/SpecForge-ext/cache/compiled_kernels/jd/cjdjlzpgoi4d6trbdutm2zq3o2chrt2dek76dhf4dqdw5ppk4vke.py new file mode 100644 index 0000000000000000000000000000000000000000..5f57a773ff4aa673212b60559fb732e418a5f0ba --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jd/cjdjlzpgoi4d6trbdutm2zq3o2chrt2dek76dhf4dqdw5ppk4vke.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/jg/cjgevvsghujpp26j5phkt5k4zrbjrbu2qnlybwiarlyz3uyulvee.py b/SpecForge-ext/cache/compiled_kernels/jg/cjgevvsghujpp26j5phkt5k4zrbjrbu2qnlybwiarlyz3uyulvee.py new file mode 100644 index 0000000000000000000000000000000000000000..b6afbaaf6dd6d11c316be76c567958610787b4f7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jg/cjgevvsghujpp26j5phkt5k4zrbjrbu2qnlybwiarlyz3uyulvee.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/jj/cjj7qa2aovrx6hv5lc745ia43xxssp7crbtxcecmmsl3zznuefjs.py b/SpecForge-ext/cache/compiled_kernels/jj/cjj7qa2aovrx6hv5lc745ia43xxssp7crbtxcecmmsl3zznuefjs.py new file mode 100644 index 0000000000000000000000000000000000000000..aa01bcd6e329ff98556f016debd46e54b264e6a7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jj/cjj7qa2aovrx6hv5lc745ia43xxssp7crbtxcecmmsl3zznuefjs.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/jj/cjjjlltuusnaz2p35p7wn4v3tircemd2by2n4yrsnt3qsaoy4pmp.py b/SpecForge-ext/cache/compiled_kernels/jj/cjjjlltuusnaz2p35p7wn4v3tircemd2by2n4yrsnt3qsaoy4pmp.py new file mode 100644 index 0000000000000000000000000000000000000000..f0edf17676a7971eeb36a69888bb565d3543ceac --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jj/cjjjlltuusnaz2p35p7wn4v3tircemd2by2n4yrsnt3qsaoy4pmp.py @@ -0,0 +1,57 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/jl/cjle7kfqci6h5nmotf7nendcz46bovinrchlquazmke3gh5aki2t.py b/SpecForge-ext/cache/compiled_kernels/jl/cjle7kfqci6h5nmotf7nendcz46bovinrchlquazmke3gh5aki2t.py new file mode 100644 index 0000000000000000000000000000000000000000..98a7528e6a5223d675b5de91459c5dd309df36b8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jl/cjle7kfqci6h5nmotf7nendcz46bovinrchlquazmke3gh5aki2t.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/jt/742a50b34481c726a5fa27a3d74847fb36d6629200c59ca289b03ace32c2ec32.best_config b/SpecForge-ext/cache/compiled_kernels/jt/742a50b34481c726a5fa27a3d74847fb36d6629200c59ca289b03ace32c2ec32.best_config new file mode 100644 index 0000000000000000000000000000000000000000..8920a6ebe9dac1a267cf3c5b5085d70019ad08a3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jt/742a50b34481c726a5fa27a3d74847fb36d6629200c59ca289b03ace32c2ec32.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 36, "triton_cache_hash": "MMGM2ESHRXPRFAROBBDYKTZUJ2JVVKU2TB5DVA3EL4OF2SOELPMQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/jt/cjtvcwf2usd2xtmomczzp2mogjtphmsmqtta6fceusmpjkttojhx.py b/SpecForge-ext/cache/compiled_kernels/jt/cjtvcwf2usd2xtmomczzp2mogjtphmsmqtta6fceusmpjkttojhx.py new file mode 100644 index 0000000000000000000000000000000000000000..80fc4f0b1b07db3b01720ee56980aef0ac6083b1 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jt/cjtvcwf2usd2xtmomczzp2mogjtphmsmqtta6fceusmpjkttojhx.py @@ -0,0 +1,86 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/jw/cjwg74y5z7pe3cojkminsnhtbxqz7bpqqf3e5xzm2zm25e4ac4d7.py b/SpecForge-ext/cache/compiled_kernels/jw/cjwg74y5z7pe3cojkminsnhtbxqz7bpqqf3e5xzm2zm25e4ac4d7.py new file mode 100644 index 0000000000000000000000000000000000000000..878bf254d5ddc212d4c72fd079f1c33ab825b16f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jw/cjwg74y5z7pe3cojkminsnhtbxqz7bpqqf3e5xzm2zm25e4ac4d7.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/jw/cjwrwqw6c6d6srrr43vnpfdvv7frmh2p6lorz4nnu3su2444db7g.py b/SpecForge-ext/cache/compiled_kernels/jw/cjwrwqw6c6d6srrr43vnpfdvv7frmh2p6lorz4nnu3su2444db7g.py new file mode 100644 index 0000000000000000000000000000000000000000..0d20856b02278d51a42e3900a87379e14d230ced --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jw/cjwrwqw6c6d6srrr43vnpfdvv7frmh2p6lorz4nnu3su2444db7g.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/k6/ck6df4wskigwgr7cy5zmyyprg45oqai7hx5qp3rxmioxi45hw42n.py b/SpecForge-ext/cache/compiled_kernels/k6/ck6df4wskigwgr7cy5zmyyprg45oqai7hx5qp3rxmioxi45hw42n.py new file mode 100644 index 0000000000000000000000000000000000000000..a2e0ac5567c514797a13f4ae911b9016508b1f3c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/k6/ck6df4wskigwgr7cy5zmyyprg45oqai7hx5qp3rxmioxi45hw42n.py @@ -0,0 +1,168 @@ +# AOT ID: ['10_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/7b/c7beunetfioua3igdykpr7pduoynrh24iikmwvjs76gpa6omyxvs.py +# Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] +# Source node to ATen node mapping: +# getitem_1 => unsqueeze +# position_mask => mul_2 +# target_mask => index +# target_mask_1 => convert_element_type +# target_max_token => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s14, 151936][151936*s14, 151936, 1]cuda:6" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[2, s14][s14, 1]cuda:6" = PlaceHolder[target=argmax] +# %arg2_1 : Tensor "b8[151936][1]cuda:6" = PlaceHolder[target=arg2_1] +# %arg3_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:6" = PlaceHolder[target=arg3_1] +# %argmax : Tensor "i64[2, s14][s14, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# %index : Tensor "b8[2, s14][s14, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%argmax]), kwargs = {}) +# %unsqueeze : Tensor "b8[2, s14, 1][s14, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {}) +# %convert_element_type : Tensor "i32[2, s14, 1][s14, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {}) +# %mul_2 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg3_1), kwargs = {}) +# return %argmax,%mul_2 +triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s24 = arg0_1 + arg1_1_size = arg1_1.size() + s14 = arg1_1_size[1] + assert_size_stride(arg1_1, (2, s14, 151936), (151936*s14, 151936, 1)) + assert_size_stride(arg2_1, (151936, ), (1, )) + assert_size_stride(arg3_1, (2, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((2, s14), (s14, 1), torch.int64) + buf1 = reinterpret_tensor(buf0, (2, s14, 1), (s14, 1, 1), 0); del buf0 # reuse + # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel = 2*s14 + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg1_1, arg2_1, arg3_1, triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel, 151936, stream=stream6) + del arg1_1 + del arg2_1 + del arg3_1 + return (buf1, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1488 + arg1_1 = rand_strided((2, 1488, 151936), (226080768, 151936, 1), device='cuda:6', dtype=torch.bfloat16) + arg2_1 = rand_strided((151936, ), (1, ), device='cuda:6', dtype=torch.bool) + arg3_1 = rand_strided((2, 1488, 1), (1488, 1, 1), device='cuda:6', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/k6/ck6kagbzevswaqcqv6y4poeutrof5seltiom2rraokkvdvi4duae.py b/SpecForge-ext/cache/compiled_kernels/k6/ck6kagbzevswaqcqv6y4poeutrof5seltiom2rraokkvdvi4duae.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1f08dfbd6ee982a359ffc5baab2b6e75339bd7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/k6/ck6kagbzevswaqcqv6y4poeutrof5seltiom2rraokkvdvi4duae.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/k6/fe6dc90c92713bfb1ec0da2712431165d134af58fb046888593d29113b005554.best_config b/SpecForge-ext/cache/compiled_kernels/k6/fe6dc90c92713bfb1ec0da2712431165d134af58fb046888593d29113b005554.best_config new file mode 100644 index 0000000000000000000000000000000000000000..0102fea510b9bf77ab661e714dfc816c066dc0d8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/k6/fe6dc90c92713bfb1ec0da2712431165d134af58fb046888593d29113b005554.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "IK5RT3JGLTF5PMMUH32NIWB2GXNU6R6CGIZSCRHU3I65YM226KDA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/kk/806d3a3fb29d494c9b1d81bd47fff98385dddf055ee66718d6f0e86ebb32e252.best_config b/SpecForge-ext/cache/compiled_kernels/kk/806d3a3fb29d494c9b1d81bd47fff98385dddf055ee66718d6f0e86ebb32e252.best_config new file mode 100644 index 0000000000000000000000000000000000000000..26f34a32396bf93c323cd255a2cf49b0585d7f4b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/kk/806d3a3fb29d494c9b1d81bd47fff98385dddf055ee66718d6f0e86ebb32e252.best_config @@ -0,0 +1 @@ +{"XBLOCK": 32, "R0_BLOCK": 16, "num_warps": 4, "num_stages": 1, "configs_hash": "21ad1ee516cd6d15e1fb8e88c10082cd54bef654f8a281c7d5ccd54b6509a685", "found_by_coordesc": false, "time_taken_ms": 28, "triton_cache_hash": "2HBOMUT44J5WFCUWYGRFAAS3HGVNDHLHT7HCSXUCAOIKU6XGJNTA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/kk/ckklxdmte7z2y72grm5gwa42pnmzupkmzc3264kqgowlcrcmm53b.py b/SpecForge-ext/cache/compiled_kernels/kk/ckklxdmte7z2y72grm5gwa42pnmzupkmzc3264kqgowlcrcmm53b.py new file mode 100644 index 0000000000000000000000000000000000000000..71b6706f6afa28a219c246413c41280f15f31d68 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/kk/ckklxdmte7z2y72grm5gwa42pnmzupkmzc3264kqgowlcrcmm53b.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/kk/ckkonpj4ig4m6kul577movgkkpytb6t5h6kpoun5efcbvgaje63a.py b/SpecForge-ext/cache/compiled_kernels/kk/ckkonpj4ig4m6kul577movgkkpytb6t5h6kpoun5efcbvgaje63a.py new file mode 100644 index 0000000000000000000000000000000000000000..2277d2efdeb06dbad40e025ed0077f112c247cc5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/kk/ckkonpj4ig4m6kul577movgkkpytb6t5h6kpoun5efcbvgaje63a.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/la/clawugqo7lvmpspporbqixaqqea5cedplrnsa3b7zbpl6fjgqmml.py b/SpecForge-ext/cache/compiled_kernels/la/clawugqo7lvmpspporbqixaqqea5cedplrnsa3b7zbpl6fjgqmml.py new file mode 100644 index 0000000000000000000000000000000000000000..b1bf421dc7423b4c2fdc51ad16f7d195ac00c3e0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/la/clawugqo7lvmpspporbqixaqqea5cedplrnsa3b7zbpl6fjgqmml.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/lf/clf3ith7wqaltbvr2kqy5ikzogwedd7ncymtl24hawio7idegek3.py b/SpecForge-ext/cache/compiled_kernels/lf/clf3ith7wqaltbvr2kqy5ikzogwedd7ncymtl24hawio7idegek3.py new file mode 100644 index 0000000000000000000000000000000000000000..d54f876f151d9d7da4f389778b83c3746f266105 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lf/clf3ith7wqaltbvr2kqy5ikzogwedd7ncymtl24hawio7idegek3.py @@ -0,0 +1,86 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/lf/clf5zaesorynqtwo2h3awvvepunjjjbxbrprayniuyz5qqrl4qym.py b/SpecForge-ext/cache/compiled_kernels/lf/clf5zaesorynqtwo2h3awvvepunjjjbxbrprayniuyz5qqrl4qym.py new file mode 100644 index 0000000000000000000000000000000000000000..3627241b4b9caef554b98c972a13eca5737661ca --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lf/clf5zaesorynqtwo2h3awvvepunjjjbxbrprayniuyz5qqrl4qym.py @@ -0,0 +1,57 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/li/48fcb2c57c5ef66d17d1db760b5dcf889dc3cfb091172d1ab4c382e7b8821f41.best_config b/SpecForge-ext/cache/compiled_kernels/li/48fcb2c57c5ef66d17d1db760b5dcf889dc3cfb091172d1ab4c382e7b8821f41.best_config new file mode 100644 index 0000000000000000000000000000000000000000..3fd0170715dceebca88888383bbe8f15eedaaeca --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/li/48fcb2c57c5ef66d17d1db760b5dcf889dc3cfb091172d1ab4c382e7b8821f41.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 14, "triton_cache_hash": "S3UH64TOYTN473KAATRMGKZ5SLQ46EZYJVPR6TIL7QNMYCB3MSMA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/li/cli5y72b3d2jp2n2a5rls4qt4wguoyjbqp6hfv4o3teqoyn3l6lx.py b/SpecForge-ext/cache/compiled_kernels/li/cli5y72b3d2jp2n2a5rls4qt4wguoyjbqp6hfv4o3teqoyn3l6lx.py new file mode 100644 index 0000000000000000000000000000000000000000..6085ec2e97296153e75b1604804a8edfa5f3265a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/li/cli5y72b3d2jp2n2a5rls4qt4wguoyjbqp6hfv4o3teqoyn3l6lx.py @@ -0,0 +1,1083 @@ +# AOT ID: ['13_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/27/c27s4qoyzyvf54snkgtay3lqlnoj3bgphotvv5xwczxe6bqovure.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[8, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[8, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 32) + x2 = xindex // ks1 + x5 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x4 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2v/c2vbm66z3map72ysgiduadjtps3nnrhjldngw5bzue3cm5xo44w5.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[8, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[8, 1, s99][s99, s99, 1]cuda:7" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[8, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:7" = PlaceHolder[target=primals_9] +# %primals_22 : Tensor "i32[8, 1, s56][s56, s56, 1]cuda:7" = PlaceHolder[target=primals_22] +# %primals_25 : Tensor "i32[8, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:7" = PlaceHolder[target=primals_25] +# %primals_17 : Tensor "i32[8, 1, s94][s94, s94, 1]cuda:7" = PlaceHolder[target=primals_17] +# %primals_20 : Tensor "i32[8, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:7" = PlaceHolder[target=primals_20] +# %primals_27 : Tensor "i32[8, 1, s100][s100, s100, 1]cuda:7" = PlaceHolder[target=primals_27] +# %primals_30 : Tensor "i32[8, 1, s6, s10][s10*s6, s10*s6, s10, 1]cuda:7" = PlaceHolder[target=primals_30] +# %primals_14 : Tensor "i64[8][1]cuda:7" = PlaceHolder[target=primals_14] +# %full_default : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s75 = primals_15 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + s94 = primals_16 + s28 = primals_18 + s4 = primals_19 + s56 = primals_21 + s53 = primals_24 + s84 = primals_23 + s100 = primals_26 + s10 = primals_29 + s6 = primals_28 + assert_size_stride(primals_2, (8, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_6, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_9, (8, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (8, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (8, ), (1, )) + assert_size_stride(primals_17, (8, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_20, (8, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_22, (8, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_25, (8, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_27, (8, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_30, (8, 1, s6, s10), (s10*s6, s10*s6, s10, 1)) + assert_size_stride(getitem, (8, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (8, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (8, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + ps0 = 32*s37 + buf1 = empty_strided_cuda((8, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + triton_red_fused_zeros_0_xnumel = 256*s37 + stream7 = get_raw_stream(7) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, s37, ps0, triton_red_fused_zeros_0_xnumel, 128, stream=stream7) + del getitem + buf3 = empty_strided_cuda((8, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream7 = get_raw_stream(7) + triton_tem_fused_zeros_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_22, primals_25, primals_17, primals_20, primals_27, primals_30, primals_14, buf5, s37, s0, s99, s22, s72, s56, s53, s84, s75, 4*((127 + s37) // 128) + ((127 + s0) // 128), 8, 8, stream=stream7) + del buf1 + del getitem_1 + del primals_13 + del primals_14 + del primals_17 + del primals_2 + del primals_20 + del primals_22 + del primals_25 + del primals_27 + del primals_30 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 2025 + primals_11 = 2025 + primals_15 = 2025 + primals_7 = 16 + primals_8 = 16 + primals_12 = 16 + primals_16 = 16 + primals_18 = 16 + primals_19 = 16 + primals_21 = 16 + primals_24 = 16 + primals_23 = 16 + primals_26 = 16 + primals_29 = 16 + primals_28 = 16 + primals_2 = rand_strided((8, 32, 2025, 128), (8294400, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_4 = rand_strided((8, 8, 2025, 128), (2073600, 259200, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_6 = rand_strided((8, 8, 2025, 128), (2073600, 259200, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_9 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_13 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_14 = rand_strided((8, ), (1, ), device='cuda:7', dtype=torch.int64) + primals_17 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_20 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_22 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_25 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_27 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_30 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + getitem = rand_strided((8, 32, 2025, 128), (8294400, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + getitem_1 = rand_strided((8, 32, 2025), (64800, 2025, 1), device='cuda:7', dtype=torch.float32) + tangents_1 = rand_strided((8, 32, 2025, 128), (8294400, 259200, 128, 1), device='cuda:7', dtype=torch.bfloat16) + fn = lambda: call([primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/li/climzgziptsjlz4tsuxacvt6muyoeemo7l2bxcyotfjy3vij6hbt.py b/SpecForge-ext/cache/compiled_kernels/li/climzgziptsjlz4tsuxacvt6muyoeemo7l2bxcyotfjy3vij6hbt.py new file mode 100644 index 0000000000000000000000000000000000000000..125a892a9c6b8349d3526099a80ffb5d2452535b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/li/climzgziptsjlz4tsuxacvt6muyoeemo7l2bxcyotfjy3vij6hbt.py @@ -0,0 +1,40 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 = tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) + tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/li/clizd63k3q2lczy7rdushloczo7hziyvb4ouvzpoxoc43es6cpoc.py b/SpecForge-ext/cache/compiled_kernels/li/clizd63k3q2lczy7rdushloczo7hziyvb4ouvzpoxoc43es6cpoc.py new file mode 100644 index 0000000000000000000000000000000000000000..c6dc6c4aca27fefcf67fbe4cdf3a685f49d2064c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/li/clizd63k3q2lczy7rdushloczo7hziyvb4ouvzpoxoc43es6cpoc.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/lq/ac09f96f5cced6d777a7f4212ede2d08394877ee82e6671455439215a023a259.best_config b/SpecForge-ext/cache/compiled_kernels/lq/ac09f96f5cced6d777a7f4212ede2d08394877ee82e6671455439215a023a259.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b3edbf99a0eba8d80382a60b6e8c1e8217d859f3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lq/ac09f96f5cced6d777a7f4212ede2d08394877ee82e6671455439215a023a259.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "L4G4T7LJLJ5UKCYTJO6FX7X7X5CAA5AHUH7H57L6AM4BNGXXAAVQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/lq/clq6j324q6ohohfcoim3i3nwrcoxenqnrdqfhsu6rmdyv2f3htkh.py b/SpecForge-ext/cache/compiled_kernels/lq/clq6j324q6ohohfcoim3i3nwrcoxenqnrdqfhsu6rmdyv2f3htkh.py new file mode 100644 index 0000000000000000000000000000000000000000..413357e692c72e02196a55c21158d69d425631b8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lq/clq6j324q6ohohfcoim3i3nwrcoxenqnrdqfhsu6rmdyv2f3htkh.py @@ -0,0 +1,26 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/lq/clq6xm6mlrdi6knqndoe2uh6tqzfa6wzx3k26ksk5imj6b4td57c.py b/SpecForge-ext/cache/compiled_kernels/lq/clq6xm6mlrdi6knqndoe2uh6tqzfa6wzx3k26ksk5imj6b4td57c.py new file mode 100644 index 0000000000000000000000000000000000000000..91f4cbcecf761200b41b51c7de3a9d97dfc7d5c2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lq/clq6xm6mlrdi6knqndoe2uh6tqzfa6wzx3k26ksk5imj6b4td57c.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 1048576000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/lq/clqbbuegamppef5fckuka5vp5vs5u35wgqa2yxhsbyvq2waia5y4.py b/SpecForge-ext/cache/compiled_kernels/lq/clqbbuegamppef5fckuka5vp5vs5u35wgqa2yxhsbyvq2waia5y4.py new file mode 100644 index 0000000000000000000000000000000000000000..d6c636aa9ac55b8fe5947ee5a8a88169fab8b2db --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lq/clqbbuegamppef5fckuka5vp5vs5u35wgqa2yxhsbyvq2waia5y4.py @@ -0,0 +1,25 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 1024}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4352}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 544 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/lq/f128dce654d98dd22e1dbb5e11e076c021abc3f6727ca00106d8d78fbef2df6d.best_config b/SpecForge-ext/cache/compiled_kernels/lq/f128dce654d98dd22e1dbb5e11e076c021abc3f6727ca00106d8d78fbef2df6d.best_config new file mode 100644 index 0000000000000000000000000000000000000000..e429df07e213631abc2aaaa2b257dd2a025bc297 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lq/f128dce654d98dd22e1dbb5e11e076c021abc3f6727ca00106d8d78fbef2df6d.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ml/cmlgrfrpyvl5gqlphdfiqxpqawlo6wyjtpwqkky6zywkkdk4h4hl.py b/SpecForge-ext/cache/compiled_kernels/ml/cmlgrfrpyvl5gqlphdfiqxpqawlo6wyjtpwqkky6zywkkdk4h4hl.py new file mode 100644 index 0000000000000000000000000000000000000000..93703ded1483dae5e073b5712149dc4576076635 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ml/cmlgrfrpyvl5gqlphdfiqxpqawlo6wyjtpwqkky6zywkkdk4h4hl.py @@ -0,0 +1,57 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/mq/cmq33rjxz47uciqulhibuqonab77aboxlu3alymsqaa5zvrwrl5c.py b/SpecForge-ext/cache/compiled_kernels/mq/cmq33rjxz47uciqulhibuqonab77aboxlu3alymsqaa5zvrwrl5c.py new file mode 100644 index 0000000000000000000000000000000000000000..39e7fa2731889125a22023a465967c984729c96a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mq/cmq33rjxz47uciqulhibuqonab77aboxlu3alymsqaa5zvrwrl5c.py @@ -0,0 +1,72 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8192, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 512 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/my/cmyh7eqzbw2eee2wnsvk7buhbk2gj2k6pknksde6c7lujjf7hxnq.py b/SpecForge-ext/cache/compiled_kernels/my/cmyh7eqzbw2eee2wnsvk7buhbk2gj2k6pknksde6c7lujjf7hxnq.py new file mode 100644 index 0000000000000000000000000000000000000000..8c86b758f612a85e8dfc6fd1f90ecf0279c52b9f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/my/cmyh7eqzbw2eee2wnsvk7buhbk2gj2k6pknksde6c7lujjf7hxnq.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/my/cmyvhpv2dpwemg2yd5ylraj2qaxki4qorus55k33d6ctudwlwrec.py b/SpecForge-ext/cache/compiled_kernels/my/cmyvhpv2dpwemg2yd5ylraj2qaxki4qorus55k33d6ctudwlwrec.py new file mode 100644 index 0000000000000000000000000000000000000000..b678156de34fa6ecee3450f7a423f731d2120e6e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/my/cmyvhpv2dpwemg2yd5ylraj2qaxki4qorus55k33d6ctudwlwrec.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/n5/62f0a14e75ab502cfd5f16197242b120cc1cc5099a3ef8b7eda2ac679d5e2f2d.best_config b/SpecForge-ext/cache/compiled_kernels/n5/62f0a14e75ab502cfd5f16197242b120cc1cc5099a3ef8b7eda2ac679d5e2f2d.best_config new file mode 100644 index 0000000000000000000000000000000000000000..65940914d216d062f2725b003671f86bc9112489 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/n5/62f0a14e75ab502cfd5f16197242b120cc1cc5099a3ef8b7eda2ac679d5e2f2d.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 56, "triton_cache_hash": "XRR2QXTZQK4DSBTDJUTNXO6FEFXI2IIRKSC5GYSBWLTL56SKI4WA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/n5/cn5h4iq6wlljobax2ulslga4k6zxontovelmyztexccj4qb2xkei.py b/SpecForge-ext/cache/compiled_kernels/n5/cn5h4iq6wlljobax2ulslga4k6zxontovelmyztexccj4qb2xkei.py new file mode 100644 index 0000000000000000000000000000000000000000..aa93a52e46259a42fd19c9e544e295b4ce3b2f37 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/n5/cn5h4iq6wlljobax2ulslga4k6zxontovelmyztexccj4qb2xkei.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/nc/0b6ab8e9248d19cc895196ea9dfce1cc1db0f1098de2ab269b92b1744cb166eb.best_config b/SpecForge-ext/cache/compiled_kernels/nc/0b6ab8e9248d19cc895196ea9dfce1cc1db0f1098de2ab269b92b1744cb166eb.best_config new file mode 100644 index 0000000000000000000000000000000000000000..2a95815b49cfc301dd2a3d06bb1b105b04bfbae7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nc/0b6ab8e9248d19cc895196ea9dfce1cc1db0f1098de2ab269b92b1744cb166eb.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "XAIV2GWX5UZL7NNOCKNWC2I6AATKI6664P6FTQPRXS2M4AR4WJWA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/nc/cncbngvvcrqapqfcl52tmrjnd5x6fl3wkde52kfgv4tst3afiraj.py b/SpecForge-ext/cache/compiled_kernels/nc/cncbngvvcrqapqfcl52tmrjnd5x6fl3wkde52kfgv4tst3afiraj.py new file mode 100644 index 0000000000000000000000000000000000000000..c0ff8ae6f37dd92901c86a459397560577d17fd0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nc/cncbngvvcrqapqfcl52tmrjnd5x6fl3wkde52kfgv4tst3afiraj.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/nc/cncp6gc3dmuptzsqtubfwfropeokpi3finyiksypukotgxjwwyrd.py b/SpecForge-ext/cache/compiled_kernels/nc/cncp6gc3dmuptzsqtubfwfropeokpi3finyiksypukotgxjwwyrd.py new file mode 100644 index 0000000000000000000000000000000000000000..5493c69f2353d159d90c6bf97c299ce5c9f41496 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nc/cncp6gc3dmuptzsqtubfwfropeokpi3finyiksypukotgxjwwyrd.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/nc/cncrmz2ukoaup352257exrmdm3fmr2vqg5omnvcupywjpmui4s36.py b/SpecForge-ext/cache/compiled_kernels/nc/cncrmz2ukoaup352257exrmdm3fmr2vqg5omnvcupywjpmui4s36.py new file mode 100644 index 0000000000000000000000000000000000000000..3dbb3486872901f274aab825f7af13d98884013e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nc/cncrmz2ukoaup352257exrmdm3fmr2vqg5omnvcupywjpmui4s36.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/nc/cnczal5hkwpmbxl4uatxs3oesqt3wbtzlonaaxjzwde4knlxefk5.py b/SpecForge-ext/cache/compiled_kernels/nc/cnczal5hkwpmbxl4uatxs3oesqt3wbtzlonaaxjzwde4knlxefk5.py new file mode 100644 index 0000000000000000000000000000000000000000..98c3bf092eea31f747ff735aff2498ea26bb7a0e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nc/cnczal5hkwpmbxl4uatxs3oesqt3wbtzlonaaxjzwde4knlxefk5.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ne/89d9ff0986e5f4744b4462fc7593d7a2747c2a53a15b926d769aa2c1528fdc67.best_config b/SpecForge-ext/cache/compiled_kernels/ne/89d9ff0986e5f4744b4462fc7593d7a2747c2a53a15b926d769aa2c1528fdc67.best_config new file mode 100644 index 0000000000000000000000000000000000000000..7dabf38caa6a32710475fe8672437703d14a89f4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ne/89d9ff0986e5f4744b4462fc7593d7a2747c2a53a15b926d769aa2c1528fdc67.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "8c03dc2e05d158372838fe4d32248dfba74b467c7576f6e1d3eb472c41b37c80", "found_by_coordesc": false, "time_taken_ms": 213, "triton_cache_hash": "HAILX5Z6XDOE3PGHQZX4ABAJBEEI4J2UQHMI6VLNTQGT6BMZJQQQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ne/cneicvkgqbfhav7auxfs57ppzwqzmhxqyv3gpizurzngng3kdyrg.py b/SpecForge-ext/cache/compiled_kernels/ne/cneicvkgqbfhav7auxfs57ppzwqzmhxqyv3gpizurzngng3kdyrg.py new file mode 100644 index 0000000000000000000000000000000000000000..564acd3b7235becba8a85e825498cf5f875f5870 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ne/cneicvkgqbfhav7auxfs57ppzwqzmhxqyv3gpizurzngng3kdyrg.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ne/cnek4kffvmq2az6ibfgdjxclelz4oz3bejintevs3jfoyxv26f5j.py b/SpecForge-ext/cache/compiled_kernels/ne/cnek4kffvmq2az6ibfgdjxclelz4oz3bejintevs3jfoyxv26f5j.py new file mode 100644 index 0000000000000000000000000000000000000000..bd68b9c2e9da420af6af937918073d9be602920a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ne/cnek4kffvmq2az6ibfgdjxclelz4oz3bejintevs3jfoyxv26f5j.py @@ -0,0 +1,543 @@ +# AOT ID: ['5_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/vx/cvxpvd4dngzkuy4urlysh3zcjg4xf6tz55pqiqu2rwikym62u5ks.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge, view +# diagnol_mask => eq +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# m => iota_2 +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub, view_7 +# suffix_mask => ge_1 +# Graph fragment: +# %arg0_1 : Tensor "i64[8][1]cuda:7" = PlaceHolder[target=arg0_1] +# %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:7"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[2048][1]cuda:7"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %ge : Tensor "b8[2048, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[8][1]cuda:7"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %index : Tensor "i64[8][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[8, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {}) +# %lt : Tensor "b8[8, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, 2048]), kwargs = {}) +# %index_1 : Tensor "i64[8][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[8, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_1 : Tensor "b8[2048][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[2048][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[8][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[8, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[8, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[8, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, 2048]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub : Tensor "i64[2048, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub, 2048), kwargs = {}) +# %eq : Tensor "b8[2048, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[8, 1, 2048, 2048][4194304, 4194304, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, 2048]), kwargs = {}) +# %view_10 : Tensor "b8[8, 1, 16, 128, 16, 128][4194304, 4194304, 262144, 2048, 128, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [8, 1, 16, 128, 16, 128]), kwargs = {}) +# %permute : Tensor "b8[8, 1, 16, 16, 128, 128][4194304, 4194304, 262144, 128, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# return %sum_1 +triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2048 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wf/cwfyyn6luby7jkq6pqhiqb44jz2jln72mtaagrllnlfp5opls7qm.py +# Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_4 => full_default_4 +# Graph fragment: +# %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# return %index_put_1 +triton_poi_fused_new_zeros_1 = async_compile.triton('triton_poi_fused_new_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 2176 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/6q/c6qvewydivg4bvc4qyzfgybxa6usm4f7lbirtgr37nhmlwfed6p5.py +# Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# arange_6 => iota_8 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# child_7 => convert_element_type_6 +# child_8 => convert_element_type_7 +# col_indices => sort +# col_indices_1 => sort_1 +# col_range => iota_5 +# col_range_1 => iota_9 +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# dense_mask_2 => full_default_1 +# dense_mask_4 => full_default_4 +# full_blocks => eq_1 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index_mask => lt_4 +# index_mask_1 => lt_5 +# lt_3 => lt_3 +# num_blocks_in_row => sum_2 +# num_blocks_in_row_1 => sum_3 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# row_indices => unsqueeze +# row_indices_1 => unsqueeze_7 +# setitem => full_default_3, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# setitem_1 => full_default_6, index_put_1, iota_10, iota_11, unsqueeze_10, unsqueeze_11, unsqueeze_12, unsqueeze_13, unsqueeze_9 +# unsqueeze_1 => unsqueeze_1 +# unsqueeze_3 => unsqueeze_8 +# valid_indices => full_default_2, where +# valid_indices_1 => full_default_5, where_1 +# Graph fragment: +# %sum_1 : Tensor "i64[8, 1, 16, 16][256, 2048, 16, 1]cuda:7" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:7" = PlaceHolder[target=sum_2] +# %sum_3 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:7" = PlaceHolder[target=sum_3] +# %buf2 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:7" = PlaceHolder[target=buf2] +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:7" = PlaceHolder[target=index_put] +# %buf4 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:7" = PlaceHolder[target=buf4] +# %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=convert_element_type_6] +# %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=convert_element_type_7] +# %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:7" = PlaceHolder[target=index_put_1] +# %gt : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_2,), kwargs = {stable: True, descending: True}) +# %eq_1 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_1, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_5,), kwargs = {stable: True, descending: True}) +# %full_default_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %iota_7 : Tensor "i64[8][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:7, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[16][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:7, requires_grad: False}) +# %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %full_default_2 : Tensor "i32[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %where : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %full_default_2), kwargs = {}) +# %full_default_3 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_3), kwargs = {}) +# %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %iota_11 : Tensor "i64[8][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %unsqueeze_11 : Tensor "i64[8, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_11, -1), kwargs = {}) +# %unsqueeze_12 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_11, -1), kwargs = {}) +# %unsqueeze_13 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_12, -1), kwargs = {}) +# %iota_10 : Tensor "i64[1][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %unsqueeze_9 : Tensor "i64[1, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_10, -1), kwargs = {}) +# %unsqueeze_10 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_9, -1), kwargs = {}) +# %iota_8 : Tensor "i32[16][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:7, requires_grad: False}) +# %unsqueeze_7 : Tensor "i32[16, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_8, -1), kwargs = {}) +# %iota_9 : Tensor "i32[16][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:7, requires_grad: False}) +# %sum_3 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_5, [-1]), kwargs = {}) +# %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, torch.int32), kwargs = {}) +# %unsqueeze_8 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_6, 3), kwargs = {}) +# %lt_5 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_9, %unsqueeze_8), kwargs = {}) +# %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_3, torch.int32), kwargs = {}) +# %full_default_5 : Tensor "i32[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %where_1 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_5, %convert_element_type_7, %full_default_5), kwargs = {}) +# %full_default_6 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_4, [%unsqueeze_13, %unsqueeze_10, %unsqueeze_7, %where_1], %full_default_6), kwargs = {}) +# return %buf2,%buf4,%sum_2,%sum_3,%convert_element_type_3,%convert_element_type_6,%convert_element_type_4,%buf9,%convert_element_type_7,%buf16 +triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 = async_compile.triton('triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bi/cbiwa42zuoemjhwwkub6gypxcryfi2fbcigroxmaahfipc6cwcmf.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf9 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:7" = PlaceHolder[target=buf9] +# %buf11 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:7" = PlaceHolder[target=buf11] +# %sum_4 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:7" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[8, 1, 16, 16][272, 272, 17, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, 16), kwargs = {}) +# %clone_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf11,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (8, ), (1, )) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((8, 1, 16, 16), (256, 2048, 16, 1), torch.int64) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] + stream7 = get_raw_stream(7) + triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0.run(arg0_1, buf0, 2048, 16384, stream=stream7) + del arg0_1 + buf15 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + stream7 = get_raw_stream(7) + triton_poi_fused_new_zeros_1.run(buf15, 2176, stream=stream7) + buf8 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + stream7 = get_raw_stream(7) + triton_poi_fused_new_zeros_1.run(buf8, 2176, stream=stream7) + buf6 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf13 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf7 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf14 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + stream7 = get_raw_stream(7) + triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.run(buf0, buf6, buf13, buf7, buf8, buf14, buf15, 128, 16, stream=stream7) + del buf0 + buf22 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf24 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream7 = get_raw_stream(7) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf8, buf22, buf24, 128, 16, stream=stream7) + del buf8 + buf19 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf21 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream7 = get_raw_stream(7) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf15, buf19, buf21, 128, 16, stream=stream7) + del buf15 + return (buf19, buf21, buf22, buf24, buf14, buf13, buf7, buf6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, ), (1, ), device='cuda:7', dtype=torch.int64) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ne/cneup3jgvtc2kyazyreujmpq5tpdftpmosibbc6a6rhgokp3hfek.py b/SpecForge-ext/cache/compiled_kernels/ne/cneup3jgvtc2kyazyreujmpq5tpdftpmosibbc6a6rhgokp3hfek.py new file mode 100644 index 0000000000000000000000000000000000000000..1cb6577bdbd330170083bbbf2e855a936c46dd64 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ne/cneup3jgvtc2kyazyreujmpq5tpdftpmosibbc6a6rhgokp3hfek.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ne/cnevwr22hrepn7wpkpcrgviodbjuo2jfrjotgishyima7j433p77.py b/SpecForge-ext/cache/compiled_kernels/ne/cnevwr22hrepn7wpkpcrgviodbjuo2jfrjotgishyima7j433p77.py new file mode 100644 index 0000000000000000000000000000000000000000..5315b498d9474e8a80ce0fdde02cc06f9757793a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ne/cnevwr22hrepn7wpkpcrgviodbjuo2jfrjotgishyima7j433p77.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/nm/cnmfzfqknv64q5haullsq4qsxjk26ffzp3j3cbgn47tsj7swt6e6.py b/SpecForge-ext/cache/compiled_kernels/nm/cnmfzfqknv64q5haullsq4qsxjk26ffzp3j3cbgn47tsj7swt6e6.py new file mode 100644 index 0000000000000000000000000000000000000000..0e35bc919c115cc048f50dc6c43e2e5ac95497da --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nm/cnmfzfqknv64q5haullsq4qsxjk26ffzp3j3cbgn47tsj7swt6e6.py @@ -0,0 +1,303 @@ +# AOT ID: ['7_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/u5/cu55o4f4khg2wuonhaoogm7cwe7beivg5otgutgnrv3xkelakvcz.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[2, 2048, 32000][65536000, 32000, 1]cuda:6" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 262144000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ii/ciiz7wynjvqkn6uv5csahwryt5x2d664u4o7ugmepfcsfcniut4v.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg1_1 : Tensor "f32[2, 2048, 32000][65760000, 32000, 1]cuda:6" = PlaceHolder[target=arg1_1] +# %argmax_1 : Tensor "i64[2, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 524288000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ci/cciixjrxmqxkrua6qu2qp43yj5dpjjjg6nmygdfgjjw2uufo6njh.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:6" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[2, 2048][2048, 1]cuda:6" = PlaceHolder[target=argmax_1] +# %arg2_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:6" = PlaceHolder[target=arg2_1] +# %arg3_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:6" = PlaceHolder[target=arg3_1] +# %sum_1 : Tensor "i64[][]cuda:6" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:6" = PlaceHolder[target=sum_2] +# %eq : Tensor "b8[2, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[2, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[2, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {}) +# %sum_2 : Tensor "i64[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_1,%sum_2,%div +triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}} +) +@triton.jit +def triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 4096 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = _tmp11 + tmp10 + _tmp11 = tl.where(r0_mask, tmp12, _tmp11) + tmp11 = tl.sum(_tmp11, 1)[:, None] + tmp13 = tmp7.to(tl.float32) + tmp14 = tmp11.to(tl.float32) + tmp15 = 1e-06 + tmp16 = triton_helpers.maximum(tmp14, tmp15) + tmp17 = (tmp13 / tmp16) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + assert_size_stride(arg0_1, (2, 2048, 32000), (65536000, 32000, 1)) + assert_size_stride(arg1_1, (2, 2048, 32000), (65760000, 32000, 1)) + assert_size_stride(arg2_1, (2, 2048, 1), (2048, 1, 1)) + assert_size_stride(arg3_1, (2, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + stream6 = get_raw_stream(6) + triton_red_fused_argmax_0.run(arg0_1, buf0, 4096, 32000, stream=stream6) + del arg0_1 + buf1 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + stream6 = get_raw_stream(6) + triton_red_fused_argmax_1.run(arg1_1, buf1, 4096, 32000, stream=stream6) + del arg1_1 + buf4 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] + stream6 = get_raw_stream(6) + triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2.run(buf0, buf1, arg2_1, arg3_1, buf4, 1, 4096, stream=stream6) + del arg2_1 + del arg3_1 + del buf0 + del buf1 + return (buf4, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, 2048, 32000), (65536000, 32000, 1), device='cuda:6', dtype=torch.bfloat16) + arg1_1 = rand_strided((2, 2048, 32000), (65760000, 32000, 1), device='cuda:6', dtype=torch.float32) + arg2_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:6', dtype=torch.int64) + arg3_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:6', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/np/ac98176c2e734f9b04f7497a2dda5d7430379f7626965c7bebdb73fea7cc1119.best_config b/SpecForge-ext/cache/compiled_kernels/np/ac98176c2e734f9b04f7497a2dda5d7430379f7626965c7bebdb73fea7cc1119.best_config new file mode 100644 index 0000000000000000000000000000000000000000..3fd0170715dceebca88888383bbe8f15eedaaeca --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/np/ac98176c2e734f9b04f7497a2dda5d7430379f7626965c7bebdb73fea7cc1119.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 14, "triton_cache_hash": "S3UH64TOYTN473KAATRMGKZ5SLQ46EZYJVPR6TIL7QNMYCB3MSMA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/np/cnp6x45zcthhoa6foyolgsvdqvn4rmfofom26u7qk7euk6rr6ib2.py b/SpecForge-ext/cache/compiled_kernels/np/cnp6x45zcthhoa6foyolgsvdqvn4rmfofom26u7qk7euk6rr6ib2.py new file mode 100644 index 0000000000000000000000000000000000000000..b9cbb3d38331873ebdb4979a3704ea49190b6c42 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/np/cnp6x45zcthhoa6foyolgsvdqvn4rmfofom26u7qk7euk6rr6ib2.py @@ -0,0 +1,40 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 = tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) + tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ns/cnsrpjlovgxpe5uxsjdupxszbn2v5ie4attqhksar2iayjruaiwi.py b/SpecForge-ext/cache/compiled_kernels/ns/cnsrpjlovgxpe5uxsjdupxszbn2v5ie4attqhksar2iayjruaiwi.py new file mode 100644 index 0000000000000000000000000000000000000000..3190e8feda3638a70292e2d73dd8a22b8974405f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ns/cnsrpjlovgxpe5uxsjdupxszbn2v5ie4attqhksar2iayjruaiwi.py @@ -0,0 +1,57 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/nu/cnuelpak7thljvq6ivqfbdypnoenh6jqw66w2vyl75sdyur5zxzw.py b/SpecForge-ext/cache/compiled_kernels/nu/cnuelpak7thljvq6ivqfbdypnoenh6jqw66w2vyl75sdyur5zxzw.py new file mode 100644 index 0000000000000000000000000000000000000000..a3abe3d605a728e2c2a3aa4884ce2a2965fc47bc --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nu/cnuelpak7thljvq6ivqfbdypnoenh6jqw66w2vyl75sdyur5zxzw.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/nu/cnuhkncuuhx44foku26ustomd5mtfdfzy2axanrafccdeth6otsg.py b/SpecForge-ext/cache/compiled_kernels/nu/cnuhkncuuhx44foku26ustomd5mtfdfzy2axanrafccdeth6otsg.py new file mode 100644 index 0000000000000000000000000000000000000000..f02d81ab200afa8294bc2f5bd6799a1fa021c373 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nu/cnuhkncuuhx44foku26ustomd5mtfdfzy2axanrafccdeth6otsg.py @@ -0,0 +1,71 @@ +# AOT ID: ['3_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1 = args + args.clear() + s21 = arg0_1 + assert_size_stride(arg1_1, (1, 1, 40980, 128), (5245440, 5245440, 128, 1)) + assert_size_stride(arg2_1, (1, 1, 40980, 128), (5245440, 5245440, 128, 1)) + return (reinterpret_tensor(arg1_1, (1, 1, s21, 128), (5245440, 5245440, 128, 1), 0), reinterpret_tensor(arg2_1, (1, 1, s21, 128), (5245440, 5245440, 128, 1), 0), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2048 + arg1_1 = rand_strided((1, 1, 40980, 128), (5245440, 5245440, 128, 1), device='cuda:6', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 1, 40980, 128), (5245440, 5245440, 128, 1), device='cuda:6', dtype=torch.bfloat16) + fn = lambda: call([arg0_1, arg1_1, arg2_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/nx/cnxyzzrtkb2x53y2inkm35xscgdtoilsp4ahxz3aomx4a2ng4rih.py b/SpecForge-ext/cache/compiled_kernels/nx/cnxyzzrtkb2x53y2inkm35xscgdtoilsp4ahxz3aomx4a2ng4rih.py new file mode 100644 index 0000000000000000000000000000000000000000..6e29e97c99ea29ee8a3651afe121e076bb59f527 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/nx/cnxyzzrtkb2x53y2inkm35xscgdtoilsp4ahxz3aomx4a2ng4rih.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/oe/coe3eymzavqxhueriqixxh3ge3w5pwfu7v55awscq3q3aebtolpx.py b/SpecForge-ext/cache/compiled_kernels/oe/coe3eymzavqxhueriqixxh3ge3w5pwfu7v55awscq3q3aebtolpx.py new file mode 100644 index 0000000000000000000000000000000000000000..046c82b1dee99a7fd1c738b5358f75d45d252400 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/oe/coe3eymzavqxhueriqixxh3ge3w5pwfu7v55awscq3q3aebtolpx.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ox/coxfjay6pu7d5jwg5ut6t3b4g7zyrv2xtah2eaxih7w5isj3rgqn.py b/SpecForge-ext/cache/compiled_kernels/ox/coxfjay6pu7d5jwg5ut6t3b4g7zyrv2xtah2eaxih7w5isj3rgqn.py new file mode 100644 index 0000000000000000000000000000000000000000..dec0a4231396f3155b7b407db5cf8371ee4a4aff --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ox/coxfjay6pu7d5jwg5ut6t3b4g7zyrv2xtah2eaxih7w5isj3rgqn.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 64, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/p4/cp4ijj7ba3ojx3goskrebxodtrzso4wryuwcis3dhm4ynbtr4x76.py b/SpecForge-ext/cache/compiled_kernels/p4/cp4ijj7ba3ojx3goskrebxodtrzso4wryuwcis3dhm4ynbtr4x76.py new file mode 100644 index 0000000000000000000000000000000000000000..a5d8144c436d549aab61f80101cdd8a56aee7f88 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/p4/cp4ijj7ba3ojx3goskrebxodtrzso4wryuwcis3dhm4ynbtr4x76.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tmp4 = tl.load(in_ptr1 + (0)) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1]) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp2.to(tl.float32) + tmp8 = 1e-06 + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = (tmp6 / tmp9) + tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None) diff --git a/SpecForge-ext/cache/compiled_kernels/p4/cp4vyx3v5hjlm2cw4jlmpuipfj5jbrdnwtpaswbehksw44leomx6.py b/SpecForge-ext/cache/compiled_kernels/p4/cp4vyx3v5hjlm2cw4jlmpuipfj5jbrdnwtpaswbehksw44leomx6.py new file mode 100644 index 0000000000000000000000000000000000000000..f2a65bbd45160b1605715a37c117ff6f0c652e98 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/p4/cp4vyx3v5hjlm2cw4jlmpuipfj5jbrdnwtpaswbehksw44leomx6.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 32) + x2 = xindex // ks1 + x5 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x4 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/p4/e061f130e28a019ee057930c34d1441fdd0628c7b3f491be58a2d571296b99dd.best_config b/SpecForge-ext/cache/compiled_kernels/p4/e061f130e28a019ee057930c34d1441fdd0628c7b3f491be58a2d571296b99dd.best_config new file mode 100644 index 0000000000000000000000000000000000000000..6bd35c457bb44bb5851a69374a1e23cf27a4eff9 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/p4/e061f130e28a019ee057930c34d1441fdd0628c7b3f491be58a2d571296b99dd.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 54, "triton_cache_hash": "BXWZSSWKBTIG7YDOE6QDLF3DYUHLUN57GPEDYW37ZDRQO2XWRGCQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/p7/cp7oi5evlluu4tzoolnivejb2h2wxctqdm2h4fyxttvr7dsyw3cu.py b/SpecForge-ext/cache/compiled_kernels/p7/cp7oi5evlluu4tzoolnivejb2h2wxctqdm2h4fyxttvr7dsyw3cu.py new file mode 100644 index 0000000000000000000000000000000000000000..e2620825cad34fcab4be0f0a856d1bcc2285a0bb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/p7/cp7oi5evlluu4tzoolnivejb2h2wxctqdm2h4fyxttvr7dsyw3cu.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/p7/d38ffec5c63d55d460573ff3e1a8241c6d082a97ceb19efeadaa87d04443c2dd.best_config b/SpecForge-ext/cache/compiled_kernels/p7/d38ffec5c63d55d460573ff3e1a8241c6d082a97ceb19efeadaa87d04443c2dd.best_config new file mode 100644 index 0000000000000000000000000000000000000000..96dd92ec5b0239e781a57a11e2928c5c0f286636 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/p7/d38ffec5c63d55d460573ff3e1a8241c6d082a97ceb19efeadaa87d04443c2dd.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "42NVHDOVRHC3TSIT2M6NVJU72L5EVVTGFXWS47GDCP2GM2XRN7KA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/q2/cq2dxb6rwep2w3qouz3knh4qwrzgdh742pem6ugipfioxktbqtf6.py b/SpecForge-ext/cache/compiled_kernels/q2/cq2dxb6rwep2w3qouz3knh4qwrzgdh742pem6ugipfioxktbqtf6.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd39dfae200360cb5f35afb62f9ffe5e3b77ae2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/q2/cq2dxb6rwep2w3qouz3knh4qwrzgdh742pem6ugipfioxktbqtf6.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/q7/cq74443yqfb7o5jaukmottfsmin7koocwm65uppitrq7ujnzm2cu.py b/SpecForge-ext/cache/compiled_kernels/q7/cq74443yqfb7o5jaukmottfsmin7koocwm65uppitrq7ujnzm2cu.py new file mode 100644 index 0000000000000000000000000000000000000000..de492a7781aeac6b0c855be0b578a391b16f8971 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/q7/cq74443yqfb7o5jaukmottfsmin7koocwm65uppitrq7ujnzm2cu.py @@ -0,0 +1,1051 @@ +# AOT ID: ['6_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/si/csivid7ys23us3bz753ofgfyl6kefmzjfmymnzsvs4zosyg73h6z.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4x/c4xykt7eysbenti5r55drq4w7k6c7fih4ifrou2alyqcn6r5enon.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:2" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:2" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:2" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=primals_4] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_9] +# %primals_10 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=primals_8] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=primals_12] +# %primals_6 : Tensor "i64[2][1]cuda:2" = PlaceHolder[target=primals_6] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1 = args + args.clear() + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (2, ), (1, )) + assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(getitem, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (2, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream2 = get_raw_stream(2) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 131072, 128, stream=stream2) + del getitem + buf3 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream2 = get_raw_stream(2) + triton_tem_fused_zeros_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_9, primals_10, primals_7, primals_8, primals_11, primals_12, primals_6, buf5, 80, 2, 8, stream=stream2) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_12 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:2', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:2', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32) + primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_6 = rand_strided((2, ), (1, ), device='cuda:2', dtype=torch.int64) + primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32) + getitem = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 2048), (65536, 2048, 1), device='cuda:2', dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:2', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/q7/cq7c3xebpviepwedtiqtx7rbpv5742gsfvx7dkelzicucjjea5h2.py b/SpecForge-ext/cache/compiled_kernels/q7/cq7c3xebpviepwedtiqtx7rbpv5742gsfvx7dkelzicucjjea5h2.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0e23a8837affe2fc917befd76c8fcdc795514e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/q7/cq7c3xebpviepwedtiqtx7rbpv5742gsfvx7dkelzicucjjea5h2.py @@ -0,0 +1,1051 @@ +# AOT ID: ['6_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4jnlpbb32eopbc2caystnepiaizyinwoncir73za7sf3sijadk.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ck/ccksctmn7dvdwu27cnkgisk777nnrd2keej2xmbgrfhagxbfw7z5.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:1" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:1" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:1" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:1" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_9] +# %primals_10 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:1" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:1" = PlaceHolder[target=primals_8] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:1" = PlaceHolder[target=primals_12] +# %primals_6 : Tensor "i64[2][1]cuda:1" = PlaceHolder[target=primals_6] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1 = args + args.clear() + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (2, ), (1, )) + assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(getitem, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (2, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream1 = get_raw_stream(1) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 131072, 128, stream=stream1) + del getitem + buf3 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream1 = get_raw_stream(1) + triton_tem_fused_zeros_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_9, primals_10, primals_7, primals_8, primals_11, primals_12, primals_6, buf5, 80, 2, 8, stream=stream1) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_12 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:1', dtype=torch.int32) + primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_6 = rand_strided((2, ), (1, ), device='cuda:1', dtype=torch.int64) + primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:1', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:1', dtype=torch.int32) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:1', dtype=torch.int32) + getitem = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 2048), (65536, 2048, 1), device='cuda:1', dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:1', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/q7/cq7ed3ned3bl7umfg6eialakjdwf7rhno4meayw34jakkuz6bga5.py b/SpecForge-ext/cache/compiled_kernels/q7/cq7ed3ned3bl7umfg6eialakjdwf7rhno4meayw34jakkuz6bga5.py new file mode 100644 index 0000000000000000000000000000000000000000..53e93ea32c52d56de7402e4d8fb926380820e3ad --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/q7/cq7ed3ned3bl7umfg6eialakjdwf7rhno4meayw34jakkuz6bga5.py @@ -0,0 +1,72 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8192, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 512 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/q7/cq7qlfw752z7qztzmvc5idn6ikz7c7qhxrmppou7flmcrl63nd7h.py b/SpecForge-ext/cache/compiled_kernels/q7/cq7qlfw752z7qztzmvc5idn6ikz7c7qhxrmppou7flmcrl63nd7h.py new file mode 100644 index 0000000000000000000000000000000000000000..16312264df7fcdcdd467ef7b2c421e7499781d26 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/q7/cq7qlfw752z7qztzmvc5idn6ikz7c7qhxrmppou7flmcrl63nd7h.py @@ -0,0 +1,41 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 131072}} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/qe/9b50f336d234dc5c123867787bd9eac8b8452d694f67f090b8cc1b62587238f9.best_config b/SpecForge-ext/cache/compiled_kernels/qe/9b50f336d234dc5c123867787bd9eac8b8452d694f67f090b8cc1b62587238f9.best_config new file mode 100644 index 0000000000000000000000000000000000000000..480196fb4e04fbb2c822576a3fd81e7866bc88fc --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qe/9b50f336d234dc5c123867787bd9eac8b8452d694f67f090b8cc1b62587238f9.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "E2MI47QNGZ2SJDA3U3EKHN7H3EYRAANF6T7N5SFT2CZJYNBAWCNQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/qe/cqeqbkboqdi3uxaw5a5tmxlncw3qggxtaklqqgfqju7a73hndnau.py b/SpecForge-ext/cache/compiled_kernels/qe/cqeqbkboqdi3uxaw5a5tmxlncw3qggxtaklqqgfqju7a73hndnau.py new file mode 100644 index 0000000000000000000000000000000000000000..d9371a6389b344042e748339abfb700e1a66769a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qe/cqeqbkboqdi3uxaw5a5tmxlncw3qggxtaklqqgfqju7a73hndnau.py @@ -0,0 +1,50 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/qg/cqgwrpl43yqzbguusw6gdkhxmikhgvadzh4xygztlm3tab3e7fez.py b/SpecForge-ext/cache/compiled_kernels/qg/cqgwrpl43yqzbguusw6gdkhxmikhgvadzh4xygztlm3tab3e7fez.py new file mode 100644 index 0000000000000000000000000000000000000000..5bbaf71ae2c0715e45604bcc0ad516a09cf00837 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qg/cqgwrpl43yqzbguusw6gdkhxmikhgvadzh4xygztlm3tab3e7fez.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/qj/cqj5277ktaoo5rg4kvnn7pm72cbfiwp7hxewmxzj4aevxoorlebn.py b/SpecForge-ext/cache/compiled_kernels/qj/cqj5277ktaoo5rg4kvnn7pm72cbfiwp7hxewmxzj4aevxoorlebn.py new file mode 100644 index 0000000000000000000000000000000000000000..6312633083e9cdd91385cefed4dfde229e6758da --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qj/cqj5277ktaoo5rg4kvnn7pm72cbfiwp7hxewmxzj4aevxoorlebn.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/qj/cqj5av3bmaggdkactt4p4mlidpqfpoiqrxgqkz5brnzcofnxabkj.py b/SpecForge-ext/cache/compiled_kernels/qj/cqj5av3bmaggdkactt4p4mlidpqfpoiqrxgqkz5brnzcofnxabkj.py new file mode 100644 index 0000000000000000000000000000000000000000..dc6894d09d59effad50c9cf705ea23ce58769241 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qj/cqj5av3bmaggdkactt4p4mlidpqfpoiqrxgqkz5brnzcofnxabkj.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/qo/0c1b4898c1e9c4588fc52cab4df1cf61d9d8dec58397105af08f6a95612a7479.best_config b/SpecForge-ext/cache/compiled_kernels/qo/0c1b4898c1e9c4588fc52cab4df1cf61d9d8dec58397105af08f6a95612a7479.best_config new file mode 100644 index 0000000000000000000000000000000000000000..ed4bbafec32134c55e06add8fdbae259cebe3543 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qo/0c1b4898c1e9c4588fc52cab4df1cf61d9d8dec58397105af08f6a95612a7479.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "EB4J5U2HKNQBLXRWK6B5L6ATOH55AWD3MB7P63KH5AKRGRDZER7A"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/qo/cqocxh7pxkxochfj6mpmo7kr52pzynovdougke3x7nmejyge2pwr.py b/SpecForge-ext/cache/compiled_kernels/qo/cqocxh7pxkxochfj6mpmo7kr52pzynovdougke3x7nmejyge2pwr.py new file mode 100644 index 0000000000000000000000000000000000000000..78cd5c713201f0806442b8dc5f343afe67afdce7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qo/cqocxh7pxkxochfj6mpmo7kr52pzynovdougke3x7nmejyge2pwr.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/qo/cqojmb5e4b5iomuis3bstfp3rn23xoq2xegyr72zgkrjnzktbusv.py b/SpecForge-ext/cache/compiled_kernels/qo/cqojmb5e4b5iomuis3bstfp3rn23xoq2xegyr72zgkrjnzktbusv.py new file mode 100644 index 0000000000000000000000000000000000000000..00a46f27b07cb80fbd67c2a4012493cd557152d3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qo/cqojmb5e4b5iomuis3bstfp3rn23xoq2xegyr72zgkrjnzktbusv.py @@ -0,0 +1,50 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 256, 'r0_': 4096}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/qr/47448a137acc86f31ba0b62156af5ec19302322620b71e04aa7855e144d48d75.best_config b/SpecForge-ext/cache/compiled_kernels/qr/47448a137acc86f31ba0b62156af5ec19302322620b71e04aa7855e144d48d75.best_config new file mode 100644 index 0000000000000000000000000000000000000000..9b080f3c548afd0cf9435a6d43c13b816a92a36b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qr/47448a137acc86f31ba0b62156af5ec19302322620b71e04aa7855e144d48d75.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "INOFCMBF4AOGTUSNRPBLV7E37E4P43AGG4323SKXUALONOEWOJUA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/qr/cqrok7demcrbt3yh6rmj2bttpopxhcu4l237sc3deioytctzun6e.py b/SpecForge-ext/cache/compiled_kernels/qr/cqrok7demcrbt3yh6rmj2bttpopxhcu4l237sc3deioytctzun6e.py new file mode 100644 index 0000000000000000000000000000000000000000..1cedc4650d7fd1ea0e9392448c4e93340c319f43 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qr/cqrok7demcrbt3yh6rmj2bttpopxhcu4l237sc3deioytctzun6e.py @@ -0,0 +1,25 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 2176 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/qw/cqwiskd5ym44ryibwdba26jjqbg7sj2ifv7xrjccdb4wjtrolvjz.py b/SpecForge-ext/cache/compiled_kernels/qw/cqwiskd5ym44ryibwdba26jjqbg7sj2ifv7xrjccdb4wjtrolvjz.py new file mode 100644 index 0000000000000000000000000000000000000000..6992b995815bab118ce6d09799824d8b4a683226 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qw/cqwiskd5ym44ryibwdba26jjqbg7sj2ifv7xrjccdb4wjtrolvjz.py @@ -0,0 +1,164 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ml/cmlgrfrpyvl5gqlphdfiqxpqawlo6wyjtpwqkky6zywkkdk4h4hl.py +# Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] +# Source node to ATen node mapping: +# getitem_1 => unsqueeze +# position_mask => mul +# target_mask => index +# target_mask_1 => convert_element_type +# target_max_token => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[2, 2048, 151936][311164928, 151936, 1]cuda:6" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:6" = PlaceHolder[target=argmax] +# %arg1_1 : Tensor "b8[151936][1]cuda:6" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:6" = PlaceHolder[target=arg2_1] +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# %index : Tensor "b8[2, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%argmax]), kwargs = {}) +# %unsqueeze : Tensor "b8[2, 2048, 1][2048, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {}) +# %convert_element_type : Tensor "i32[2, 2048, 1][2048, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {}) +# %mul : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg2_1), kwargs = {}) +# return %argmax,%mul +triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1 = args + args.clear() + assert_size_stride(arg0_1, (2, 2048, 151936), (311164928, 151936, 1)) + assert_size_stride(arg1_1, (151936, ), (1, )) + assert_size_stride(arg2_1, (2, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64) + buf1 = reinterpret_tensor(buf0, (2, 2048, 1), (2048, 1, 1), 0); del buf0 # reuse + # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg0_1, arg1_1, arg2_1, 4096, 151936, stream=stream6) + del arg0_1 + del arg1_1 + del arg2_1 + return (buf1, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, 2048, 151936), (311164928, 151936, 1), device='cuda:6', dtype=torch.bfloat16) + arg1_1 = rand_strided((151936, ), (1, ), device='cuda:6', dtype=torch.bool) + arg2_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:6', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/re/c41d9bfeac257cd8a9bf6081d91819762871bee2bda93eaf935dfbe63cb123f4.best_config b/SpecForge-ext/cache/compiled_kernels/re/c41d9bfeac257cd8a9bf6081d91819762871bee2bda93eaf935dfbe63cb123f4.best_config new file mode 100644 index 0000000000000000000000000000000000000000..9fcd69eff532e4347a9de460be081b110cdc50c2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/re/c41d9bfeac257cd8a9bf6081d91819762871bee2bda93eaf935dfbe63cb123f4.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 49, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/re/creb6srffs2jup5j23uclgnmakvbjubbby54cykdyekpdtqxsfrm.py b/SpecForge-ext/cache/compiled_kernels/re/creb6srffs2jup5j23uclgnmakvbjubbby54cykdyekpdtqxsfrm.py new file mode 100644 index 0000000000000000000000000000000000000000..fb99e4a54f863648141ee8d0eaf883231589662d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/re/creb6srffs2jup5j23uclgnmakvbjubbby54cykdyekpdtqxsfrm.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 8192}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/re/creieqkebcgqszi536wf6onykyyucm3256hnqdandhgvc2blonfo.py b/SpecForge-ext/cache/compiled_kernels/re/creieqkebcgqszi536wf6onykyyucm3256hnqdandhgvc2blonfo.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d7f9b6afb9aee2a15339f7b26fd2e3fc9cecb4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/re/creieqkebcgqszi536wf6onykyyucm3256hnqdandhgvc2blonfo.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/re/dff8d1ad795f6d3ea651c6bb3047a9ee76eb08ae69ebda1118c1e4d18d4c8269.best_config b/SpecForge-ext/cache/compiled_kernels/re/dff8d1ad795f6d3ea651c6bb3047a9ee76eb08ae69ebda1118c1e4d18d4c8269.best_config new file mode 100644 index 0000000000000000000000000000000000000000..2a95815b49cfc301dd2a3d06bb1b105b04bfbae7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/re/dff8d1ad795f6d3ea651c6bb3047a9ee76eb08ae69ebda1118c1e4d18d4c8269.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "XAIV2GWX5UZL7NNOCKNWC2I6AATKI6664P6FTQPRXS2M4AR4WJWA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/rm/crmfgeenggpe7hoot35x4eji7uf7h6kj6uq5zcsn2zuahh3agba4.py b/SpecForge-ext/cache/compiled_kernels/rm/crmfgeenggpe7hoot35x4eji7uf7h6kj6uq5zcsn2zuahh3agba4.py new file mode 100644 index 0000000000000000000000000000000000000000..3a56359a76c043df2b8cac1a70406326bdced808 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/rm/crmfgeenggpe7hoot35x4eji7uf7h6kj6uq5zcsn2zuahh3agba4.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/rr/crrkviqx4sevtl6x7gcbx7hgs3v3iisaodfab5qgfdcouplhadtr.py b/SpecForge-ext/cache/compiled_kernels/rr/crrkviqx4sevtl6x7gcbx7hgs3v3iisaodfab5qgfdcouplhadtr.py new file mode 100644 index 0000000000000000000000000000000000000000..1914ac2e177f5aa88c35398a201c62f5cbe993c6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/rr/crrkviqx4sevtl6x7gcbx7hgs3v3iisaodfab5qgfdcouplhadtr.py @@ -0,0 +1,1051 @@ +# AOT ID: ['6_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/c3/cc3guwnwiox3yzzjtaquh6k4sm6nn4lcmkep56rop3grqr44xorh.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/yf/cyfx2gbe7xciksdo6za7eqxy3ntimrhw7eszk4msjxpi4gvt4qju.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:7" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:7" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:7" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_9] +# %primals_10 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=primals_8] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=primals_12] +# %primals_6 : Tensor "i64[2][1]cuda:7" = PlaceHolder[target=primals_6] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1 = args + args.clear() + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (2, ), (1, )) + assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(getitem, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (2, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream7 = get_raw_stream(7) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 131072, 128, stream=stream7) + del getitem + buf3 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream7 = get_raw_stream(7) + triton_tem_fused_zeros_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_9, primals_10, primals_7, primals_8, primals_11, primals_12, primals_6, buf5, 80, 2, 8, stream=stream7) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_12 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_6 = rand_strided((2, ), (1, ), device='cuda:7', dtype=torch.int64) + primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:7', dtype=torch.int32) + getitem = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 2048), (65536, 2048, 1), device='cuda:7', dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/rr/crrus5k7xklfjaq2xlcks4vmdztndb3meqkvhq6xjel5tzpesulf.py b/SpecForge-ext/cache/compiled_kernels/rr/crrus5k7xklfjaq2xlcks4vmdztndb3meqkvhq6xjel5tzpesulf.py new file mode 100644 index 0000000000000000000000000000000000000000..97ec2ef8c1124e18434d9086a45225ca273f3d6b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/rr/crrus5k7xklfjaq2xlcks4vmdztndb3meqkvhq6xjel5tzpesulf.py @@ -0,0 +1,44 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None) diff --git a/SpecForge-ext/cache/compiled_kernels/rz/crzszi74kpu7oi7kcekyvxjo2zlnsomuo7nye74lm2jrnhee2o7l.py b/SpecForge-ext/cache/compiled_kernels/rz/crzszi74kpu7oi7kcekyvxjo2zlnsomuo7nye74lm2jrnhee2o7l.py new file mode 100644 index 0000000000000000000000000000000000000000..4e635622cce9d77ef24f97503f58623bdef76e2d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/rz/crzszi74kpu7oi7kcekyvxjo2zlnsomuo7nye74lm2jrnhee2o7l.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/s4/cs4afdu7ezaeekoshfdryoga6jabuq2nrx5xdkgxrehrkuvy5jri.py b/SpecForge-ext/cache/compiled_kernels/s4/cs4afdu7ezaeekoshfdryoga6jabuq2nrx5xdkgxrehrkuvy5jri.py new file mode 100644 index 0000000000000000000000000000000000000000..cb137c95580d79e7403a9ac2088192f7d2ef23a7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/s4/cs4afdu7ezaeekoshfdryoga6jabuq2nrx5xdkgxrehrkuvy5jri.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 32) + x2 = xindex // ks1 + x5 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x4 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/s4/ed1e011c9106623d0084987034d839c6c08de94ff77afd907f394d282f34f7b1.best_config b/SpecForge-ext/cache/compiled_kernels/s4/ed1e011c9106623d0084987034d839c6c08de94ff77afd907f394d282f34f7b1.best_config new file mode 100644 index 0000000000000000000000000000000000000000..5a4516876787361d952a91f80b4c5e6d16b6ce65 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/s4/ed1e011c9106623d0084987034d839c6c08de94ff77afd907f394d282f34f7b1.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 61, "triton_cache_hash": "BXWZSSWKBTIG7YDOE6QDLF3DYUHLUN57GPEDYW37ZDRQO2XWRGCQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/s6/7f95968e11ed97f7130f87d331d4ca9eba21ae7f23d9af09e3acc9796f8af34b.best_config b/SpecForge-ext/cache/compiled_kernels/s6/7f95968e11ed97f7130f87d331d4ca9eba21ae7f23d9af09e3acc9796f8af34b.best_config new file mode 100644 index 0000000000000000000000000000000000000000..a337a719c6503c8dcbad0c427c4a5067600d0bd0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/s6/7f95968e11ed97f7130f87d331d4ca9eba21ae7f23d9af09e3acc9796f8af34b.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "6FB7I6IASCIGI3DSKLBL4Q2CXFFWPYWXW7AMHNUUDLPGKUCB3PDA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/si/00bd1ea0d84eb0da96a9f0379d19e1b25030faeec599b79b07bb27746b186dc6.best_config b/SpecForge-ext/cache/compiled_kernels/si/00bd1ea0d84eb0da96a9f0379d19e1b25030faeec599b79b07bb27746b186dc6.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c1c51c5048e176f0cf0b0d2646bd98c4186a3cba --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/si/00bd1ea0d84eb0da96a9f0379d19e1b25030faeec599b79b07bb27746b186dc6.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "R0_BLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 60, "triton_cache_hash": "EGDJYO36DUYGK3UQBUH6S7RMVKF77GGHWVMFFZR5R4TDMIZ4YVJA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/si/330dd62e96df95234f6dae47539a12a239a40e0b519e106052a0304511009b0d.best_config b/SpecForge-ext/cache/compiled_kernels/si/330dd62e96df95234f6dae47539a12a239a40e0b519e106052a0304511009b0d.best_config new file mode 100644 index 0000000000000000000000000000000000000000..6f71eaf82c5e74dc5aa6a4789bcbb6adf0a92d55 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/si/330dd62e96df95234f6dae47539a12a239a40e0b519e106052a0304511009b0d.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "R0_BLOCK": 16, "num_warps": 2, "num_stages": 1, "configs_hash": "9889a3900cf19f2f3cbdf50dfff07c1cd9bb504be42b4c95a8b2b6f156e5f333", "found_by_coordesc": false, "time_taken_ms": 34, "triton_cache_hash": "CSBUDPF5G22GDISE2XI2DNQQGLLFGYBXWKH7266NZF7EZHMJNZGA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/si/a0663295c70aebd52ccbd8f9b19b8f13c79a4ac4d3ea6dd038b1dc41a56a909b.best_config b/SpecForge-ext/cache/compiled_kernels/si/a0663295c70aebd52ccbd8f9b19b8f13c79a4ac4d3ea6dd038b1dc41a56a909b.best_config new file mode 100644 index 0000000000000000000000000000000000000000..bccc339c530946640745852b622c73570887ab86 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/si/a0663295c70aebd52ccbd8f9b19b8f13c79a4ac4d3ea6dd038b1dc41a56a909b.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "CLTRXNE5MHPP3O5A5W4Z4EQTTZVYMOP5IPJT6N44O6FTBZFXLMNA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/si/csifv3wugwezbiapa7tfir7475ta4akdayy7mujclukw2elqooca.py b/SpecForge-ext/cache/compiled_kernels/si/csifv3wugwezbiapa7tfir7475ta4akdayy7mujclukw2elqooca.py new file mode 100644 index 0000000000000000000000000000000000000000..39f33b0e78744c494050345d42af24766148582a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/si/csifv3wugwezbiapa7tfir7475ta4akdayy7mujclukw2elqooca.py @@ -0,0 +1,26 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/si/csikk3n53lwihovd25mpe5kyjf7hnym4zk3xgfxvpcikm4bgg35n.py b/SpecForge-ext/cache/compiled_kernels/si/csikk3n53lwihovd25mpe5kyjf7hnym4zk3xgfxvpcikm4bgg35n.py new file mode 100644 index 0000000000000000000000000000000000000000..58bb080cdb3d6666af09b7ddd5955f6a41a892c3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/si/csikk3n53lwihovd25mpe5kyjf7hnym4zk3xgfxvpcikm4bgg35n.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/si/csimgsywujh2hsabqlfrkiolvsb67iqr7fvv654a3o76p2cedho4.py b/SpecForge-ext/cache/compiled_kernels/si/csimgsywujh2hsabqlfrkiolvsb67iqr7fvv654a3o76p2cedho4.py new file mode 100644 index 0000000000000000000000000000000000000000..45fc7f7a0b254c28127fda20e98fc54ba84ab066 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/si/csimgsywujh2hsabqlfrkiolvsb67iqr7fvv654a3o76p2cedho4.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/si/csivid7ys23us3bz753ofgfyl6kefmzjfmymnzsvs4zosyg73h6z.py b/SpecForge-ext/cache/compiled_kernels/si/csivid7ys23us3bz753ofgfyl6kefmzjfmymnzsvs4zosyg73h6z.py new file mode 100644 index 0000000000000000000000000000000000000000..45b80f5405d44f36b0291a9e59ff4d6dae45ca85 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/si/csivid7ys23us3bz753ofgfyl6kefmzjfmymnzsvs4zosyg73h6z.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) diff --git a/SpecForge-ext/cache/compiled_kernels/t4/940bf06436a36c0021994f736bf764744ff4a5312dacac6f4772cf7aac0855ce.best_config b/SpecForge-ext/cache/compiled_kernels/t4/940bf06436a36c0021994f736bf764744ff4a5312dacac6f4772cf7aac0855ce.best_config new file mode 100644 index 0000000000000000000000000000000000000000..a013ad4c2a7b9a18e1d475008c8b3e320dca3141 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/t4/940bf06436a36c0021994f736bf764744ff4a5312dacac6f4772cf7aac0855ce.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 52, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/t4/ct4ja5lnaomv5oj7757f5xs5m47uf73w3ia3qy4uznjw6fr7z7gi.py b/SpecForge-ext/cache/compiled_kernels/t4/ct4ja5lnaomv5oj7757f5xs5m47uf73w3ia3qy4uznjw6fr7z7gi.py new file mode 100644 index 0000000000000000000000000000000000000000..b924b7e58c58bae0b104fd566792b0c4b3c8835b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/t4/ct4ja5lnaomv5oj7757f5xs5m47uf73w3ia3qy4uznjw6fr7z7gi.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/tg/ctgvjtmmibb2qmiq3ddlrosqut7a6nt4ofshy2uyjanfab2rlaqe.py b/SpecForge-ext/cache/compiled_kernels/tg/ctgvjtmmibb2qmiq3ddlrosqut7a6nt4ofshy2uyjanfab2rlaqe.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c420975040a810f7213dd9c79f182a9ff5c5a1 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/tg/ctgvjtmmibb2qmiq3ddlrosqut7a6nt4ofshy2uyjanfab2rlaqe.py @@ -0,0 +1,1051 @@ +# AOT ID: ['6_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/og/cogol55cthk4zevsy3dlqiyzipefv735ge2wtaddvk436qht5nox.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/iu/ciubsrlh6ebtmbsanuldd2ketrtc6ptf3mgb7oftea6g5arlci73.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:6" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:6" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:6" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_9] +# %primals_10 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=primals_8] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=primals_12] +# %primals_6 : Tensor "i64[2][1]cuda:6" = PlaceHolder[target=primals_6] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1 = args + args.clear() + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (2, ), (1, )) + assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(getitem, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (2, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream6 = get_raw_stream(6) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 131072, 128, stream=stream6) + del getitem + buf3 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream6 = get_raw_stream(6) + triton_tem_fused_zeros_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_9, primals_10, primals_7, primals_8, primals_11, primals_12, primals_6, buf5, 80, 2, 8, stream=stream6) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_12 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:6', dtype=torch.int32) + primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_6 = rand_strided((2, ), (1, ), device='cuda:6', dtype=torch.int64) + primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:6', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:6', dtype=torch.int32) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:6', dtype=torch.int32) + getitem = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 2048), (65536, 2048, 1), device='cuda:6', dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:6', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ti/a8152bfbf99523e09b33ed3877cd4dc93fc0827a07c7ddca8371990989a5c6ef.best_config b/SpecForge-ext/cache/compiled_kernels/ti/a8152bfbf99523e09b33ed3877cd4dc93fc0827a07c7ddca8371990989a5c6ef.best_config new file mode 100644 index 0000000000000000000000000000000000000000..2a95815b49cfc301dd2a3d06bb1b105b04bfbae7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ti/a8152bfbf99523e09b33ed3877cd4dc93fc0827a07c7ddca8371990989a5c6ef.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "XAIV2GWX5UZL7NNOCKNWC2I6AATKI6664P6FTQPRXS2M4AR4WJWA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/to/ctogfjlnz7aechutp6s6dodeisr6uojmljrw6bxgwh74bxvmvf6w.py b/SpecForge-ext/cache/compiled_kernels/to/ctogfjlnz7aechutp6s6dodeisr6uojmljrw6bxgwh74bxvmvf6w.py new file mode 100644 index 0000000000000000000000000000000000000000..cf30c0e9896218c192cd26dc003582d1ebeae120 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/to/ctogfjlnz7aechutp6s6dodeisr6uojmljrw6bxgwh74bxvmvf6w.py @@ -0,0 +1,1083 @@ +# AOT ID: ['13_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wp/cwpp2ogxi4ziv4d6g6hohwssakk6mbdlaj4nklq5voabubccx3d6.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 32) + x2 = xindex // ks1 + x5 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x4 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/es/ces6igirh7ild5cxrl3jkv5ib25midu6s5yyh2tqsvbm3cwwomwg.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:5" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[2, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:5" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:5" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[2, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:5" = PlaceHolder[target=primals_9] +# %primals_22 : Tensor "i32[2, 1, s56][s56, s56, 1]cuda:5" = PlaceHolder[target=primals_22] +# %primals_25 : Tensor "i32[2, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:5" = PlaceHolder[target=primals_25] +# %primals_17 : Tensor "i32[2, 1, s94][s94, s94, 1]cuda:5" = PlaceHolder[target=primals_17] +# %primals_20 : Tensor "i32[2, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:5" = PlaceHolder[target=primals_20] +# %primals_27 : Tensor "i32[2, 1, s100][s100, s100, 1]cuda:5" = PlaceHolder[target=primals_27] +# %primals_30 : Tensor "i32[2, 1, s6, s10][s10*s6, s10*s6, s10, 1]cuda:5" = PlaceHolder[target=primals_30] +# %primals_14 : Tensor "i64[2][1]cuda:5" = PlaceHolder[target=primals_14] +# %full_default : Tensor "f32[2, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s75 = primals_15 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + s94 = primals_16 + s28 = primals_18 + s4 = primals_19 + s56 = primals_21 + s53 = primals_24 + s84 = primals_23 + s100 = primals_26 + s10 = primals_29 + s6 = primals_28 + assert_size_stride(primals_2, (2, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_6, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_9, (2, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (2, ), (1, )) + assert_size_stride(primals_17, (2, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_20, (2, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_22, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_25, (2, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_27, (2, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_30, (2, 1, s6, s10), (s10*s6, s10*s6, s10, 1)) + assert_size_stride(getitem, (2, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (2, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + ps0 = 32*s37 + buf1 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + triton_red_fused_zeros_0_xnumel = 64*s37 + stream5 = get_raw_stream(5) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, s37, ps0, triton_red_fused_zeros_0_xnumel, 128, stream=stream5) + del getitem + buf3 = empty_strided_cuda((2, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream5 = get_raw_stream(5) + triton_tem_fused_zeros_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_22, primals_25, primals_17, primals_20, primals_27, primals_30, primals_14, buf5, s37, s0, s99, s22, s72, s56, s53, s84, s75, 4*((127 + s37) // 128) + ((127 + s0) // 128), 2, 8, stream=stream5) + del buf1 + del getitem_1 + del primals_13 + del primals_14 + del primals_17 + del primals_2 + del primals_20 + del primals_22 + del primals_25 + del primals_27 + del primals_30 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 1569 + primals_11 = 1569 + primals_15 = 1569 + primals_7 = 13 + primals_8 = 13 + primals_12 = 13 + primals_16 = 13 + primals_18 = 13 + primals_19 = 13 + primals_21 = 13 + primals_24 = 13 + primals_23 = 13 + primals_26 = 13 + primals_29 = 13 + primals_28 = 13 + primals_2 = rand_strided((2, 32, 1569, 128), (6426624, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 8, 1569, 128), (1606656, 200832, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_6 = rand_strided((2, 8, 1569, 128), (1606656, 200832, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_9 = rand_strided((2, 1, 13, 13), (169, 169, 13, 1), device='cuda:5', dtype=torch.int32) + primals_13 = rand_strided((2, 1, 13), (13, 13, 1), device='cuda:5', dtype=torch.int32) + primals_14 = rand_strided((2, ), (1, ), device='cuda:5', dtype=torch.int64) + primals_17 = rand_strided((2, 1, 13), (13, 13, 1), device='cuda:5', dtype=torch.int32) + primals_20 = rand_strided((2, 1, 13, 13), (169, 169, 13, 1), device='cuda:5', dtype=torch.int32) + primals_22 = rand_strided((2, 1, 13), (13, 13, 1), device='cuda:5', dtype=torch.int32) + primals_25 = rand_strided((2, 1, 13, 13), (169, 169, 13, 1), device='cuda:5', dtype=torch.int32) + primals_27 = rand_strided((2, 1, 13), (13, 13, 1), device='cuda:5', dtype=torch.int32) + primals_30 = rand_strided((2, 1, 13, 13), (169, 169, 13, 1), device='cuda:5', dtype=torch.int32) + getitem = rand_strided((2, 32, 1569, 128), (6426624, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 1569), (50208, 1569, 1), device='cuda:5', dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 1569, 128), (6426624, 200832, 128, 1), device='cuda:5', dtype=torch.bfloat16) + fn = lambda: call([primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/u6/cu6zj2ebnyrrymk64iekh3sw5x54t46cfjhj5gjoqkk4c3twoioy.py b/SpecForge-ext/cache/compiled_kernels/u6/cu6zj2ebnyrrymk64iekh3sw5x54t46cfjhj5gjoqkk4c3twoioy.py new file mode 100644 index 0000000000000000000000000000000000000000..5a18eb48251d26db4eced9906887adb33ac921f3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/u6/cu6zj2ebnyrrymk64iekh3sw5x54t46cfjhj5gjoqkk4c3twoioy.py @@ -0,0 +1,352 @@ +# AOT ID: ['14_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/42/c42h7visn4guss7swxj4up2er4ije4hyno7yrughuvurnenh2pvd.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s3, 32000][32000*s3, 32000, 1]cuda:6" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/57/c57svaeo74ael4oxqveudfvhx4xfmu3ikrmvljcnixb4kiqagrzn.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg3_1 : Tensor "f32[2, s3, 32000][s71, 32000, 1]cuda:6" = PlaceHolder[target=arg3_1] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg3_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2v/c2vabblrjzyryauc2jram5kwgwvjexq53bdwxugagjegc2xvufuy.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq_2 +# mul => mul_3 +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:6" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:6" = PlaceHolder[target=argmax_1] +# %arg4_1 : Tensor "i64[2, s3, 1][s3, 1, 1]cuda:6" = PlaceHolder[target=arg4_1] +# %eq_2 : Tensor "b8[2, s3][s3, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[2, s3][s3, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {}) +# %mul_3 : Tensor "i64[2, s3][s3, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_3,), kwargs = {}) +# return %sum_1 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/64/c64aw6tdhc533bdu4lu2kexzu7e3rgjk5xeentmkjen77ksnc56t.py +# Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %arg6_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:6" = PlaceHolder[target=arg6_1] +# %sum_1 : Tensor "i64[][]cuda:6" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:6" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_2,%div +triton_red_fused_clamp_min_div_sum_3 = async_compile.triton('triton_red_fused_clamp_min_div_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tmp4 = tl.load(in_ptr1 + (0)) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1]) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp2.to(tl.float32) + tmp8 = 1e-06 + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = (tmp6 / tmp9) + tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1 = args + args.clear() + s3 = arg0_1 + s71 = arg2_1 + s14 = arg5_1 + assert_size_stride(arg1_1, (2, s3, 32000), (32000*s3, 32000, 1)) + assert_size_stride(arg3_1, (2, s3, 32000), (s71, 32000, 1)) + assert_size_stride(arg4_1, (2, s3, 1), (s3, 1, 1)) + assert_size_stride(arg6_1, (2, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + triton_red_fused_argmax_0_xnumel = 2*s3 + stream6 = get_raw_stream(6) + triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream6) + del arg1_1 + buf1 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + triton_red_fused_argmax_1_xnumel = 2*s3 + stream6 = get_raw_stream(6) + triton_red_fused_argmax_1.run(arg3_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream6) + del arg3_1 + buf2 = empty_strided_cuda((), (), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 2*s3 + stream6 = get_raw_stream(6) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg4_1, buf2, 1, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream6) + del arg4_1 + del buf0 + del buf1 + buf4 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] + triton_red_fused_clamp_min_div_sum_3_r0_numel = 2*s14 + stream6 = get_raw_stream(6) + triton_red_fused_clamp_min_div_sum_3.run(arg6_1, buf2, buf4, 1, triton_red_fused_clamp_min_div_sum_3_r0_numel, stream=stream6) + del arg6_1 + del buf2 + return (buf4, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1488 + arg1_1 = rand_strided((2, 1488, 32000), (47616000, 32000, 1), device='cuda:6', dtype=torch.bfloat16) + arg2_1 = 47840000 + arg3_1 = rand_strided((2, 1488, 32000), (47840000, 32000, 1), device='cuda:6', dtype=torch.float32) + arg4_1 = rand_strided((2, 1488, 1), (1488, 1, 1), device='cuda:6', dtype=torch.int64) + arg5_1 = 1488 + arg6_1 = rand_strided((2, 1488, 1), (1488, 1, 1), device='cuda:6', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/uo/cuoflv2dnlrrnqtf32tqz435pzpw3hvrmamzms4siwsvxeljqc7k.py b/SpecForge-ext/cache/compiled_kernels/uo/cuoflv2dnlrrnqtf32tqz435pzpw3hvrmamzms4siwsvxeljqc7k.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc0f881e99608211d9a070d51a448e17cc8c930 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/uo/cuoflv2dnlrrnqtf32tqz435pzpw3hvrmamzms4siwsvxeljqc7k.py @@ -0,0 +1,41 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 131072}} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ut/cutp3chhk5c6s5fxb2gqzhrx5hjq4ltt3ybguoemttw3toknshg6.py b/SpecForge-ext/cache/compiled_kernels/ut/cutp3chhk5c6s5fxb2gqzhrx5hjq4ltt3ybguoemttw3toknshg6.py new file mode 100644 index 0000000000000000000000000000000000000000..fe315958d167fe0bf1b2a6cd35e09ce3a9b7a2bd --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ut/cutp3chhk5c6s5fxb2gqzhrx5hjq4ltt3ybguoemttw3toknshg6.py @@ -0,0 +1,37 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, 0) + tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32) + tl.store(out_ptr0 + (x0), tmp4, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/uw/cuw54qt3adv2mqfssnwpmkprl2vhbjhxprmh3bb5537pl6evv4j7.py b/SpecForge-ext/cache/compiled_kernels/uw/cuw54qt3adv2mqfssnwpmkprl2vhbjhxprmh3bb5537pl6evv4j7.py new file mode 100644 index 0000000000000000000000000000000000000000..b86d65dca31dd96af29d380637a7e0b94dd5b9ad --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/uw/cuw54qt3adv2mqfssnwpmkprl2vhbjhxprmh3bb5537pl6evv4j7.py @@ -0,0 +1,44 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None) diff --git a/SpecForge-ext/cache/compiled_kernels/v6/cv6zby6rpuyg246nnmlustgtoid6plqf4t2w45xg2axk333ati55.py b/SpecForge-ext/cache/compiled_kernels/v6/cv6zby6rpuyg246nnmlustgtoid6plqf4t2w45xg2axk333ati55.py new file mode 100644 index 0000000000000000000000000000000000000000..001f39bbb9cf741ef607f1adf41a2a7e39e796d9 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/v6/cv6zby6rpuyg246nnmlustgtoid6plqf4t2w45xg2axk333ati55.py @@ -0,0 +1,57 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/vk/cvk2qr7hggrizog6osippdtnv4g54aa5mwpdaz7y5pik3awumasg.py b/SpecForge-ext/cache/compiled_kernels/vk/cvk2qr7hggrizog6osippdtnv4g54aa5mwpdaz7y5pik3awumasg.py new file mode 100644 index 0000000000000000000000000000000000000000..0597dc34c18e662011c4e3fcd0c71cc3c3799ae5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vk/cvk2qr7hggrizog6osippdtnv4g54aa5mwpdaz7y5pik3awumasg.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 1048576000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/vk/cvkcbju4ftdjugozv3aumhlgwacbn2h4ae4bwnnofexgmrt5upru.py b/SpecForge-ext/cache/compiled_kernels/vk/cvkcbju4ftdjugozv3aumhlgwacbn2h4ae4bwnnofexgmrt5upru.py new file mode 100644 index 0000000000000000000000000000000000000000..ebf990aec06452eff991ce15084306c448d91554 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vk/cvkcbju4ftdjugozv3aumhlgwacbn2h4ae4bwnnofexgmrt5upru.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/vk/cvkx62l2jl3ze5txwk2jhv7ld6m52dbjnbi6pzjn4mq3t6pawfk3.py b/SpecForge-ext/cache/compiled_kernels/vk/cvkx62l2jl3ze5txwk2jhv7ld6m52dbjnbi6pzjn4mq3t6pawfk3.py new file mode 100644 index 0000000000000000000000000000000000000000..d7e1996c5d34ea9e8226b848ddc780377a9ea1c7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vk/cvkx62l2jl3ze5txwk2jhv7ld6m52dbjnbi6pzjn4mq3t6pawfk3.py @@ -0,0 +1,320 @@ +# AOT ID: ['4_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3t/c3t6fb3abpggxvyu2nmh2a6rfcx73pzqtlq42b2wdr3nbjf6cyni.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_2 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:5" = PlaceHolder[target=tangents_2] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:5" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_84 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {}) +# %slice_5 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {}) +# %slice_6 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_2 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {}) +# %full_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_10, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %slice_scatter_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {}) +# %add_100 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_85 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {}) +# %add_101 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {}) +# return %add_101 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/w3/cw3xkrzc4xrvq5kejf7ygu6gb5dutkng3ftrzniiui5s4f3mkqat.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:5" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {}) +# %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {}) +# %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {}) +# %full_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {}) +# %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {}) +# %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {}) +# %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {}) +# return %add_107 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args + args.clear() + s24 = primals_2 + s9 = primals_7 + s48 = primals_10 + s34 = primals_11 + s92 = primals_1 + s96 = primals_3 + s79 = primals_5 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1)) + assert_size_stride(tangents_2, (s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s9*s48*s48 + stream5 = get_raw_stream(5) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream5) + del tangents_2 + buf1 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9 + stream5 = get_raw_stream(5) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream5) + del primals_4 + del primals_6 + del primals_8 + del tangents_1 + return (None, None, None, None, None, None, None, None, None, None, None, buf1, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_2 = 128 + primals_7 = 2048 + primals_10 = 8 + primals_11 = 32 + primals_1 = 2048 + primals_3 = 5245440 + primals_5 = 2048 + floordiv = 64 + add_96 = 64 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:5', dtype=torch.int64) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16) + tangents_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16) + fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/vn/cvn736oi3yulmpiv2kyhznjdsmsi3u35zxuqvuyabq7sna42w72l.py b/SpecForge-ext/cache/compiled_kernels/vn/cvn736oi3yulmpiv2kyhznjdsmsi3u35zxuqvuyabq7sna42w72l.py new file mode 100644 index 0000000000000000000000000000000000000000..68b804506081bc15c632302356a219ab8e28f1c2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vn/cvn736oi3yulmpiv2kyhznjdsmsi3u35zxuqvuyabq7sna42w72l.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/vs/cvsjb3xujy534tq24hswfuky4funbyak56lp7i6o6rheqa22cc5p.py b/SpecForge-ext/cache/compiled_kernels/vs/cvsjb3xujy534tq24hswfuky4funbyak56lp7i6o6rheqa22cc5p.py new file mode 100644 index 0000000000000000000000000000000000000000..f44fd5ff4ff52490706193c230ac2a291d17f222 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vs/cvsjb3xujy534tq24hswfuky4funbyak56lp7i6o6rheqa22cc5p.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py b/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py new file mode 100644 index 0000000000000000000000000000000000000000..d7b3f71de4235393aa6c5f15b2c09341affa29ce --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/wc/cwcw6e7zfjg335twcczs7lhvjsmpvg4pyoo7npy5nddyaqv4hfzb.py b/SpecForge-ext/cache/compiled_kernels/wc/cwcw6e7zfjg335twcczs7lhvjsmpvg4pyoo7npy5nddyaqv4hfzb.py new file mode 100644 index 0000000000000000000000000000000000000000..0320a9a900c645ca13d46d2fe0124f807deb6108 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wc/cwcw6e7zfjg335twcczs7lhvjsmpvg4pyoo7npy5nddyaqv4hfzb.py @@ -0,0 +1,164 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ns/cnsrpjlovgxpe5uxsjdupxszbn2v5ie4attqhksar2iayjruaiwi.py +# Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] +# Source node to ATen node mapping: +# getitem_1 => unsqueeze +# position_mask => mul +# target_mask => index +# target_mask_1 => convert_element_type +# target_max_token => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[2, 2048, 151936][311164928, 151936, 1]cuda:0" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:0" = PlaceHolder[target=argmax] +# %arg1_1 : Tensor "b8[151936][1]cuda:0" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:0" = PlaceHolder[target=arg2_1] +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# %index : Tensor "b8[2, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%argmax]), kwargs = {}) +# %unsqueeze : Tensor "b8[2, 2048, 1][2048, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {}) +# %convert_element_type : Tensor "i32[2, 2048, 1][2048, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {}) +# %mul : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg2_1), kwargs = {}) +# return %argmax,%mul +triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1 = args + args.clear() + assert_size_stride(arg0_1, (2, 2048, 151936), (311164928, 151936, 1)) + assert_size_stride(arg1_1, (151936, ), (1, )) + assert_size_stride(arg2_1, (2, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64) + buf1 = reinterpret_tensor(buf0, (2, 2048, 1), (2048, 1, 1), 0); del buf0 # reuse + # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] + stream0 = get_raw_stream(0) + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg0_1, arg1_1, arg2_1, 4096, 151936, stream=stream0) + del arg0_1 + del arg1_1 + del arg2_1 + return (buf1, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, 2048, 151936), (311164928, 151936, 1), device='cuda:0', dtype=torch.bfloat16) + arg1_1 = rand_strided((151936, ), (1, ), device='cuda:0', dtype=torch.bool) + arg2_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:0', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/wc/f8d7a470e8bde42e927c0d84733e5ff4aa16ac944909d292071d2bb1acd048d9.best_config b/SpecForge-ext/cache/compiled_kernels/wc/f8d7a470e8bde42e927c0d84733e5ff4aa16ac944909d292071d2bb1acd048d9.best_config new file mode 100644 index 0000000000000000000000000000000000000000..ae7f45fa18a6ff5e69d13d9b7fbac7dc07677004 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wc/f8d7a470e8bde42e927c0d84733e5ff4aa16ac944909d292071d2bb1acd048d9.best_config @@ -0,0 +1 @@ +{"XBLOCK": 32, "R0_BLOCK": 16, "num_warps": 4, "num_stages": 1, "configs_hash": "21ad1ee516cd6d15e1fb8e88c10082cd54bef654f8a281c7d5ccd54b6509a685", "found_by_coordesc": false, "time_taken_ms": 29, "triton_cache_hash": "2HBOMUT44J5WFCUWYGRFAAS3HGVNDHLHT7HCSXUCAOIKU6XGJNTA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/wf/1e276d7586d12688f1a837a69cefa10254d3399cb702696e2d63fd0d3b422291.best_config b/SpecForge-ext/cache/compiled_kernels/wf/1e276d7586d12688f1a837a69cefa10254d3399cb702696e2d63fd0d3b422291.best_config new file mode 100644 index 0000000000000000000000000000000000000000..7d56ea7451f6ff3ceffec392bc015b86ab20533e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wf/1e276d7586d12688f1a837a69cefa10254d3399cb702696e2d63fd0d3b422291.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "4UWYNBR3KPWQGNAZ5LIIRE7YAZWTQP4CP3JS6GOSLWYDF5K7WTAA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/wf/b1411a221e106b7a93616bbb971d336e179aadfae5e785dbe21148f6ffe923bb.best_config b/SpecForge-ext/cache/compiled_kernels/wf/b1411a221e106b7a93616bbb971d336e179aadfae5e785dbe21148f6ffe923bb.best_config new file mode 100644 index 0000000000000000000000000000000000000000..a570e8d663ff6e600f50df05a811c859065ec3c4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wf/b1411a221e106b7a93616bbb971d336e179aadfae5e785dbe21148f6ffe923bb.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 21, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/wf/cwf6bkjfjzdctzdez7j3aj4ebefwcsqlz4gci5drec5pqysdr7fn.py b/SpecForge-ext/cache/compiled_kernels/wf/cwf6bkjfjzdctzdez7j3aj4ebefwcsqlz4gci5drec5pqysdr7fn.py new file mode 100644 index 0000000000000000000000000000000000000000..9d2956d055ee570102135e3fb8aa848f50e63f0b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wf/cwf6bkjfjzdctzdez7j3aj4ebefwcsqlz4gci5drec5pqysdr7fn.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/wf/cwfe7sfhj752wancob47q72uij65kfje36vvsyqvwyq7aed4zfns.py b/SpecForge-ext/cache/compiled_kernels/wf/cwfe7sfhj752wancob47q72uij65kfje36vvsyqvwyq7aed4zfns.py new file mode 100644 index 0000000000000000000000000000000000000000..69af5b8339bdf3b56f3f7b5050e75bda74838a9a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wf/cwfe7sfhj752wancob47q72uij65kfje36vvsyqvwyq7aed4zfns.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/wf/cwfyyn6luby7jkq6pqhiqb44jz2jln72mtaagrllnlfp5opls7qm.py b/SpecForge-ext/cache/compiled_kernels/wf/cwfyyn6luby7jkq6pqhiqb44jz2jln72mtaagrllnlfp5opls7qm.py new file mode 100644 index 0000000000000000000000000000000000000000..000f34c1ca658935ecdd0b83c5298f5e2e77d345 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wf/cwfyyn6luby7jkq6pqhiqb44jz2jln72mtaagrllnlfp5opls7qm.py @@ -0,0 +1,25 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 2176 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/wk/58ebbee78c1cb2658910e222cb57287349628b9e6a51fcafd38720ef4450877b.best_config b/SpecForge-ext/cache/compiled_kernels/wk/58ebbee78c1cb2658910e222cb57287349628b9e6a51fcafd38720ef4450877b.best_config new file mode 100644 index 0000000000000000000000000000000000000000..2a95815b49cfc301dd2a3d06bb1b105b04bfbae7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wk/58ebbee78c1cb2658910e222cb57287349628b9e6a51fcafd38720ef4450877b.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "XAIV2GWX5UZL7NNOCKNWC2I6AATKI6664P6FTQPRXS2M4AR4WJWA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/wk/cwk6qcbhztyvjyrdotkvyjvochdzirkan7xywtdlfzreg6ld74ts.py b/SpecForge-ext/cache/compiled_kernels/wk/cwk6qcbhztyvjyrdotkvyjvochdzirkan7xywtdlfzreg6ld74ts.py new file mode 100644 index 0000000000000000000000000000000000000000..18738ec3dca91bdf3cff0185e482e241cc6788fe --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wk/cwk6qcbhztyvjyrdotkvyjvochdzirkan7xywtdlfzreg6ld74ts.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/wk/cwknascjc5r3t6bclny2fnjjakqaebtac6hrvubwnh2e5yl5qhk3.py b/SpecForge-ext/cache/compiled_kernels/wk/cwknascjc5r3t6bclny2fnjjakqaebtac6hrvubwnh2e5yl5qhk3.py new file mode 100644 index 0000000000000000000000000000000000000000..985b07b53fc1d27fce629abdc89ff5ebf00709d7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wk/cwknascjc5r3t6bclny2fnjjakqaebtac6hrvubwnh2e5yl5qhk3.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 8192}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ws/cwscasifvcovoelwi4d7vjwpz2nhgtvtiz5adyxxcjk5omtuzd4t.py b/SpecForge-ext/cache/compiled_kernels/ws/cwscasifvcovoelwi4d7vjwpz2nhgtvtiz5adyxxcjk5omtuzd4t.py new file mode 100644 index 0000000000000000000000000000000000000000..614d513d08b33c8443ce2a1719b5a04d3c52051a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ws/cwscasifvcovoelwi4d7vjwpz2nhgtvtiz5adyxxcjk5omtuzd4t.py @@ -0,0 +1,71 @@ +# AOT ID: ['3_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1 = args + args.clear() + s21 = arg0_1 + assert_size_stride(arg1_1, (1, 1, 40980, 128), (5245440, 5245440, 128, 1)) + assert_size_stride(arg2_1, (1, 1, 40980, 128), (5245440, 5245440, 128, 1)) + return (reinterpret_tensor(arg1_1, (1, 1, s21, 128), (5245440, 5245440, 128, 1), 0), reinterpret_tensor(arg2_1, (1, 1, s21, 128), (5245440, 5245440, 128, 1), 0), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2048 + arg1_1 = rand_strided((1, 1, 40980, 128), (5245440, 5245440, 128, 1), device='cuda:4', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 1, 40980, 128), (5245440, 5245440, 128, 1), device='cuda:4', dtype=torch.bfloat16) + fn = lambda: call([arg0_1, arg1_1, arg2_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/wy/cwyjr2eky5gz34uakabhhj5hqtafgvo3lud5teff44oxgtaxnc2f.py b/SpecForge-ext/cache/compiled_kernels/wy/cwyjr2eky5gz34uakabhhj5hqtafgvo3lud5teff44oxgtaxnc2f.py new file mode 100644 index 0000000000000000000000000000000000000000..ef34966f46c56d6ec15f92142bc85d0aaef9b434 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wy/cwyjr2eky5gz34uakabhhj5hqtafgvo3lud5teff44oxgtaxnc2f.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tmp4 = tl.load(in_ptr1 + (0)) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1]) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp2.to(tl.float32) + tmp8 = 1e-06 + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = (tmp6 / tmp9) + tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None) diff --git a/SpecForge-ext/cache/compiled_kernels/x2/cx2edpaed257yr34nr5n6yv7v4hojwvpu6m3g2pnovwimvbfqtuj.py b/SpecForge-ext/cache/compiled_kernels/x2/cx2edpaed257yr34nr5n6yv7v4hojwvpu6m3g2pnovwimvbfqtuj.py new file mode 100644 index 0000000000000000000000000000000000000000..86bfd5e37bef3551f4d22b5f8c4a130d43a59715 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/x2/cx2edpaed257yr34nr5n6yv7v4hojwvpu6m3g2pnovwimvbfqtuj.py @@ -0,0 +1,1051 @@ +# AOT ID: ['6_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/m7/cm7u3olama3gox426hxhxixqvzhslez5o7pvi4bnehh2g4ww6k6i.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[8, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bd/cbdpymknkquuerovirx6corahubfs5khfhys2add2b3c2zkuvlup.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:1" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:1" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:1" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:1" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_9] +# %primals_10 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1" = PlaceHolder[target=primals_8] +# %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1" = PlaceHolder[target=primals_12] +# %primals_6 : Tensor "i64[8][1]cuda:1" = PlaceHolder[target=primals_6] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1 = args + args.clear() + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (8, ), (1, )) + assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(getitem, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (8, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (8, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream1 = get_raw_stream(1) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 524288, 128, stream=stream1) + del getitem + buf3 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((8, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((8, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream1 = get_raw_stream(1) + triton_tem_fused_zeros_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_9, primals_10, primals_7, primals_8, primals_11, primals_12, primals_6, buf5, 80, 8, 8, stream=stream1) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_12 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:1', dtype=torch.int32) + primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_6 = rand_strided((8, ), (1, ), device='cuda:1', dtype=torch.int64) + primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:1', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:1', dtype=torch.int32) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:1', dtype=torch.int32) + getitem = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + getitem_1 = rand_strided((8, 32, 2048), (65536, 2048, 1), device='cuda:1', dtype=torch.float32) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:1', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/x5/cx5mjkcv5uwwjqfmpxvjsxr77i2vnrar2c7mvcr2sz7r4tshmogg.py b/SpecForge-ext/cache/compiled_kernels/x5/cx5mjkcv5uwwjqfmpxvjsxr77i2vnrar2c7mvcr2sz7r4tshmogg.py new file mode 100644 index 0000000000000000000000000000000000000000..5536629d7cc627f2ba4b64b5cad4bd100fa5eda9 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/x5/cx5mjkcv5uwwjqfmpxvjsxr77i2vnrar2c7mvcr2sz7r4tshmogg.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/x5/cx5w4h5vj5kxgcwzb4viv5yobmli7guubk3q3mwejfzpzdkfoyb6.py b/SpecForge-ext/cache/compiled_kernels/x5/cx5w4h5vj5kxgcwzb4viv5yobmli7guubk3q3mwejfzpzdkfoyb6.py new file mode 100644 index 0000000000000000000000000000000000000000..414c795eb8531b685b498ae2ebe3b9cfbddae4f1 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/x5/cx5w4h5vj5kxgcwzb4viv5yobmli7guubk3q3mwejfzpzdkfoyb6.py @@ -0,0 +1,159 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/c5/cc5dyv2gy7kqwwgof22mbw3houj3mwz3mpm5wwkls5nzlyig75gr.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:2" = PlaceHolder[target=arg0_1] +# %getitem : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:2" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:2" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 32000), (65536000, 32000, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf2 = empty_strided_cuda((8, 2048, 32000), (65536000, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + stream2 = get_raw_stream(2) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg0_1, buf2, 16384, 32000, stream=stream2) + del arg0_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 32000), (65536000, 32000, 1), device='cuda:2', dtype=torch.bfloat16) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/y2/3541e25303ba7527995bbd5888095e817222c07fd9fd69bb0365420223c67f03.best_config b/SpecForge-ext/cache/compiled_kernels/y2/3541e25303ba7527995bbd5888095e817222c07fd9fd69bb0365420223c67f03.best_config new file mode 100644 index 0000000000000000000000000000000000000000..39aa06f1122c6eb2904338d2578102fd0e126a89 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/y2/3541e25303ba7527995bbd5888095e817222c07fd9fd69bb0365420223c67f03.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/y2/cy2z3prnenjcb4hrgo5vbmijazodmwndrxkivwmm3zcr5cgu6obg.py b/SpecForge-ext/cache/compiled_kernels/y2/cy2z3prnenjcb4hrgo5vbmijazodmwndrxkivwmm3zcr5cgu6obg.py new file mode 100644 index 0000000000000000000000000000000000000000..79b0879c9aefadedec3950b77afff979c6de2e0f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/y2/cy2z3prnenjcb4hrgo5vbmijazodmwndrxkivwmm3zcr5cgu6obg.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 256, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/y5/cy5vf4gnz7zle5ospiqwxpaakdrocn2nwwwhum3styir6fju6b6t.py b/SpecForge-ext/cache/compiled_kernels/y5/cy5vf4gnz7zle5ospiqwxpaakdrocn2nwwwhum3styir6fju6b6t.py new file mode 100644 index 0000000000000000000000000000000000000000..21321c7a6c869729f85eaed86fe6adff93b5eee7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/y5/cy5vf4gnz7zle5ospiqwxpaakdrocn2nwwwhum3styir6fju6b6t.py @@ -0,0 +1,57 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/ye/cyea7zwtzaxv6b2klni7blgyz6kwza7fz6xa27qfd5prfdbgw4p4.py b/SpecForge-ext/cache/compiled_kernels/ye/cyea7zwtzaxv6b2klni7blgyz6kwza7fz6xa27qfd5prfdbgw4p4.py new file mode 100644 index 0000000000000000000000000000000000000000..2f045a6decad1e3d2b9025244820c950990299b1 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ye/cyea7zwtzaxv6b2klni7blgyz6kwza7fz6xa27qfd5prfdbgw4p4.py @@ -0,0 +1,1051 @@ +# AOT ID: ['6_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/5y/c5youjzyi3z3ynjz75h25htk6unkxftgtdsn4apk4b37ykfisbjl.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:5" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, 2048][65536, 2048, 1]cuda:5" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxcw73ulf57rliruk2q2qa2i7oeoplgvcxvn5jg5wjp4b2ksrxd4.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:5" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:5" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:5" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:5" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:5" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=primals_9] +# %primals_10 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:5" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:5" = PlaceHolder[target=primals_8] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:5" = PlaceHolder[target=primals_12] +# %primals_6 : Tensor "i64[2][1]cuda:5" = PlaceHolder[target=primals_6] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1 = args + args.clear() + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (2, ), (1, )) + assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(getitem, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (2, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream5 = get_raw_stream(5) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 131072, 128, stream=stream5) + del getitem + buf3 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream5 = get_raw_stream(5) + triton_tem_fused_zeros_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_9, primals_10, primals_7, primals_8, primals_11, primals_12, primals_6, buf5, 80, 2, 8, stream=stream5) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_12 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32) + primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_6 = rand_strided((2, ), (1, ), device='cuda:5', dtype=torch.int64) + primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32) + getitem = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 2048), (65536, 2048, 1), device='cuda:5', dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ye/cyekdgh4wmn74jvtympiiylqnrvhwts56hfa325wb2lv7qflcp6o.py b/SpecForge-ext/cache/compiled_kernels/ye/cyekdgh4wmn74jvtympiiylqnrvhwts56hfa325wb2lv7qflcp6o.py new file mode 100644 index 0000000000000000000000000000000000000000..48f1a3d2647a5222dc14ae063adbf8aa49792413 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ye/cyekdgh4wmn74jvtympiiylqnrvhwts56hfa325wb2lv7qflcp6o.py @@ -0,0 +1,159 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/46/c46dhrndu3aq5tou7kbkolfu6xneb3gy6qyyi2vbwptr5jq223bs.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg0_1 : Tensor "bf16[2, 2048, 32000][65536000, 32000, 1]cuda:5" = PlaceHolder[target=arg0_1] +# %getitem : Tensor "f32[2, 2048, 1][2048, 1, 4096]cuda:5" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[2, 2048, 1][2048, 1, 4096]cuda:5" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 1310720000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (2, 2048, 32000), (65536000, 32000, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf2 = empty_strided_cuda((2, 2048, 32000), (65536000, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + stream5 = get_raw_stream(5) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg0_1, buf2, 4096, 32000, stream=stream5) + del arg0_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, 2048, 32000), (65536000, 32000, 1), device='cuda:5', dtype=torch.bfloat16) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/yh/cyhmejo3vybiioddrnxvzkiyv7k7mq4q42mtldpgossn56vqsii2.py b/SpecForge-ext/cache/compiled_kernels/yh/cyhmejo3vybiioddrnxvzkiyv7k7mq4q42mtldpgossn56vqsii2.py new file mode 100644 index 0000000000000000000000000000000000000000..4a359a65b01ea4108eb849a27eea3d8cb7809ec2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/yh/cyhmejo3vybiioddrnxvzkiyv7k7mq4q42mtldpgossn56vqsii2.py @@ -0,0 +1,693 @@ +# AOT ID: ['9_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4u/c4uf4o6eypfpqr4isgii4opqr5i3brobwecljte7sqvztk2kyafz.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:1" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:1" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:1" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=buf1] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[2, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:1" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[2, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:1" = PlaceHolder[target=primals_13] +# %primals_10 : Tensor "i64[2][1]cuda:1" = PlaceHolder[target=primals_10] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_3, %primals_5, %sdpa_score0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21 = args + args.clear() + s0 = primals_2 + s43 = primals_4 + s72 = primals_6 + s71 = primals_8 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_5, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (2, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, ), (1, )) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (2, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (2, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (2, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream1 = get_raw_stream(1) + triton_tem_fused_0.run(primals_1, primals_3, primals_5, buf0, buf1, primals_9, primals_7, primals_11, primals_13, primals_10, buf2, s0, s72, 16, 2, 32, stream=stream1) + del buf1 + return (buf2, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, buf2, buf0, s0, s72, s4, s56, s84, s99, s6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_2 = 4096 + primals_3 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_4 = 4096 + primals_5 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_6 = 32 + primals_7 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:1', dtype=torch.int32) + primals_8 = 4096 + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_10 = rand_strided((2, ), (1, ), device='cuda:1', dtype=torch.int64) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_12 = 32 + primals_13 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:1', dtype=torch.int32) + primals_14 = 32 + primals_15 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:1', dtype=torch.int32) + primals_16 = 32 + primals_17 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:1', dtype=torch.int32) + primals_18 = 32 + primals_19 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:1', dtype=torch.int32) + primals_20 = 32 + primals_21 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:1', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ym/b9f4684b59d9715ca7689ae6bc338522329d6df5de39bb88d4f7fb7848e92ae1.best_config b/SpecForge-ext/cache/compiled_kernels/ym/b9f4684b59d9715ca7689ae6bc338522329d6df5de39bb88d4f7fb7848e92ae1.best_config new file mode 100644 index 0000000000000000000000000000000000000000..e9b96a126fb37b684d7d003c0adf1a0efd4c8fc6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ym/b9f4684b59d9715ca7689ae6bc338522329d6df5de39bb88d4f7fb7848e92ae1.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 20, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ym/cymroqqgat67pxc4rnyt5dzn7thy5u4fbc3jl5cbhuh5rx3hz3bb.py b/SpecForge-ext/cache/compiled_kernels/ym/cymroqqgat67pxc4rnyt5dzn7thy5u4fbc3jl5cbhuh5rx3hz3bb.py new file mode 100644 index 0000000000000000000000000000000000000000..3a94f15174b0a041bcf33d97534c1504199cf51c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ym/cymroqqgat67pxc4rnyt5dzn7thy5u4fbc3jl5cbhuh5rx3hz3bb.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ys/cyswjqjzlfkq4xdob2fved3zkdu7tdip5o73b34z2gdtmnr2vruw.py b/SpecForge-ext/cache/compiled_kernels/ys/cyswjqjzlfkq4xdob2fved3zkdu7tdip5o73b34z2gdtmnr2vruw.py new file mode 100644 index 0000000000000000000000000000000000000000..bb010ed4d7efa36f6f0689c3a1b6600a367628b5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ys/cyswjqjzlfkq4xdob2fved3zkdu7tdip5o73b34z2gdtmnr2vruw.py @@ -0,0 +1,711 @@ +# AOT ID: ['13_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4m/c4mv34wib446qhr7sd5yhgc4mdneb7isnb6uitnbwvdgrbpgyf2s.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:3" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[2, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:3" = PlaceHolder[target=primals_9] +# %primals_17 : Tensor "i32[2, 1, s94][s94, s94, 1]cuda:3" = PlaceHolder[target=primals_17] +# %primals_20 : Tensor "i32[2, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:3" = PlaceHolder[target=primals_20] +# %primals_14 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=primals_14] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + s75 = primals_15 + s94 = primals_16 + s28 = primals_18 + s4 = primals_19 + s56 = primals_21 + s84 = primals_23 + s53 = primals_24 + s100 = primals_26 + s6 = primals_28 + s10 = primals_29 + assert_size_stride(primals_2, (2, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_6, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_9, (2, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (2, ), (1, )) + assert_size_stride(primals_17, (2, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_20, (2, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_22, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_25, (2, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_27, (2, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_30, (2, 1, s6, s10), (s10*s6, s10*s6, s10, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream3 = get_raw_stream(3) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_17, primals_20, primals_14, buf2, s37, s0, s99, s22, s72, s75, (127 + s37) // 128, 2, 32, stream=stream3) + del buf1 + return (buf2, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, buf2, buf0, s37, s0, s75, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2014 + primals_2 = rand_strided((2, 32, 2014, 128), (8249344, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_3 = 2014 + primals_4 = rand_strided((2, 8, 2014, 128), (2062336, 257792, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_5 = 2014 + primals_6 = rand_strided((2, 8, 2014, 128), (2062336, 257792, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_7 = 16 + primals_8 = 16 + primals_9 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_10 = 2014 + primals_11 = 2014 + primals_12 = 16 + primals_13 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_14 = rand_strided((2, ), (1, ), device='cuda:3', dtype=torch.int64) + primals_15 = 2014 + primals_16 = 16 + primals_17 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_18 = 16 + primals_19 = 16 + primals_20 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_21 = 16 + primals_22 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_23 = 16 + primals_24 = 16 + primals_25 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_26 = 16 + primals_27 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_28 = 16 + primals_29 = 16 + primals_30 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/yu/345bee7d0ba798dda05ca8d9c0d89fab5c3e9287e332b4cd143934196f13958b.best_config b/SpecForge-ext/cache/compiled_kernels/yu/345bee7d0ba798dda05ca8d9c0d89fab5c3e9287e332b4cd143934196f13958b.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b5fe0bd195e2afa4eb939871edb76221f1e8606e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/yu/345bee7d0ba798dda05ca8d9c0d89fab5c3e9287e332b4cd143934196f13958b.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "E2MI47QNGZ2SJDA3U3EKHN7H3EYRAANF6T7N5SFT2CZJYNBAWCNQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/yu/cyuwuklbk6glbjmohlzaemm25tilic35nartuvsd4hkf3pb4ixgd.py b/SpecForge-ext/cache/compiled_kernels/yu/cyuwuklbk6glbjmohlzaemm25tilic35nartuvsd4hkf3pb4ixgd.py new file mode 100644 index 0000000000000000000000000000000000000000..be6178b5b152d9743c940170ebff9804b381892b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/yu/cyuwuklbk6glbjmohlzaemm25tilic35nartuvsd4hkf3pb4ixgd.py @@ -0,0 +1,50 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/yx/cyx32mrbk6dmmczjafautcv3nsbj2s4rr6zdyb6ie3lu622orrdx.py b/SpecForge-ext/cache/compiled_kernels/yx/cyx32mrbk6dmmczjafautcv3nsbj2s4rr6zdyb6ie3lu622orrdx.py new file mode 100644 index 0000000000000000000000000000000000000000..55d18767463e82a257791df2c745c5254e096a8d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/yx/cyx32mrbk6dmmczjafautcv3nsbj2s4rr6zdyb6ie3lu622orrdx.py @@ -0,0 +1,99 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // ks0) % ks1) + x0 = (xindex % ks0) + x2 = xindex // ks4 + _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = ks2 + tmp2 = tmp0 < tmp1 + tmp3 = r0_3 + 128*x0 + tmp4 = ks3 + tmp5 = tmp3 < tmp4 + tmp6 = tmp2 & tmp5 + tmp7 = r0_4 + 128*x1 + tmp8 = r0_3 + 128*x0 + tmp9 = tmp7 >= tmp8 + tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0) + tmp11 = tmp8 < tmp10 + tmp12 = tmp7 < tmp10 + tmp13 = tmp11 & tmp12 + tmp14 = tmp9 & tmp13 + tmp15 = tl.full([1, 1], False, tl.int1) + tmp16 = tmp15 | tmp14 + tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK]) + tmp18 = tmp8 >= tmp17 + tmp19 = (tmp8 % tmp17) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp17 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tmp27 < tmp10 + tmp29 = tmp18 & tmp28 + tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp31 = (tmp30 % tmp17) + tmp32 = tmp31 != tmp20 + tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0 + tmp34 = tmp33 != tmp23 + tmp35 = tmp32 & tmp34 + tmp36 = tmp31 + tmp17 + tmp37 = tl.where(tmp35, tmp36, tmp31) + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp37 == tmp38 + tmp40 = tmp29 & tmp39 + tmp41 = tmp16 | tmp40 + tmp42 = tl.full(tmp41.shape, False, tmp41.dtype) + tmp43 = tl.where(tmp6, tmp41, tmp42) + tmp44 = tmp43.to(tl.int64) + tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) + tmp47 = _tmp46 + tmp45 + _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46) + tmp46 = tl.sum(_tmp46, 1)[:, None] + tmp48 = tl.full([1, 1], 0, tl.int64) + tmp49 = tmp46 > tmp48 + tmp50 = tl.full([1, 1], 16384, tl.int64) + tmp51 = tmp46 < tmp50 + tmp52 = tmp49 & tmp51 + tmp53 = tmp52.to(tl.int8) + tmp54 = tmp53.to(tl.int32) + tmp55 = tmp46 == tmp50 + tmp56 = tmp55.to(tl.int8) + tmp57 = tmp56.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp54, xmask) + tl.store(out_ptr2 + (x5), tmp57, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/yx/cyxrygdxulahgoszpzyjmm5hgv3rfk3nc7jyhzgu437mmg6vobyi.py b/SpecForge-ext/cache/compiled_kernels/yx/cyxrygdxulahgoszpzyjmm5hgv3rfk3nc7jyhzgu437mmg6vobyi.py new file mode 100644 index 0000000000000000000000000000000000000000..95b1ce6be58b9ec137cbe196ff8c5f486749fd65 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/yx/cyxrygdxulahgoszpzyjmm5hgv3rfk3nc7jyhzgu437mmg6vobyi.py @@ -0,0 +1,1065 @@ +# AOT ID: ['9_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:0" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, 2048][65536, 2048, 1]cuda:0" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/vg/cvgwtgnvsowtrkc6lrj5h4cxx2xnzqdwjfgarwfheacxnbzxnzda.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:0" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:0" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=getitem_5] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[2, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:0" = PlaceHolder[target=primals_7] +# %primals_15 : Tensor "i32[2, 1, s56][s56, s56, 1]cuda:0" = PlaceHolder[target=primals_15] +# %primals_17 : Tensor "i32[2, 1, s84, 16][16*s84, 16*s84, 16, 1]cuda:0" = PlaceHolder[target=primals_17] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[2, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:0" = PlaceHolder[target=primals_13] +# %primals_19 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:0" = PlaceHolder[target=primals_19] +# %primals_21 : Tensor "i32[2, 1, s6, 16][16*s6, 16*s6, 16, 1]cuda:0" = PlaceHolder[target=primals_21] +# %primals_10 : Tensor "i64[2][1]cuda:0" = PlaceHolder[target=primals_10] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1 = args + args.clear() + s0 = primals_8 + s72 = primals_6 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_5, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (2, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, ), (1, )) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (2, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (2, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (2, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + assert_size_stride(getitem, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (2, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream0 = get_raw_stream(0) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 131072, 128, stream=stream0) + del getitem + buf3 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream0 = get_raw_stream(0) + triton_tem_fused_zeros_1.run(primals_1, primals_3, primals_5, getitem_1, buf1, tangents_1, buf3, buf4, primals_9, primals_7, primals_15, primals_17, primals_11, primals_13, primals_19, primals_21, primals_10, buf5, s0, s72, s56, s84, 64 + ((127 + s0) // 128), 2, 8, stream=stream0) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_13 + del primals_15 + del primals_17 + del primals_19 + del primals_21 + del primals_3 + del primals_5 + del primals_7 + del primals_9 + del tangents_1 + return (buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 4096 + primals_6 = 32 + primals_12 = 32 + primals_14 = 32 + primals_16 = 32 + primals_18 = 32 + primals_20 = 32 + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_5 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_7 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:0', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_10 = rand_strided((2, ), (1, ), device='cuda:0', dtype=torch.int64) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_13 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:0', dtype=torch.int32) + primals_15 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:0', dtype=torch.int32) + primals_17 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:0', dtype=torch.int32) + primals_19 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:0', dtype=torch.int32) + primals_21 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:0', dtype=torch.int32) + getitem = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 2048), (65536, 2048, 1), device='cuda:0', dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) + fn = lambda: call([primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/z4/90f0dec4e81bc24642b491872ca8b816f7e67ba488d9bac9f450603283865a49.best_config b/SpecForge-ext/cache/compiled_kernels/z4/90f0dec4e81bc24642b491872ca8b816f7e67ba488d9bac9f450603283865a49.best_config new file mode 100644 index 0000000000000000000000000000000000000000..f1c46524ae475f95a41419c3265ac06e5e818e68 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/z4/90f0dec4e81bc24642b491872ca8b816f7e67ba488d9bac9f450603283865a49.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "R0_BLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 61, "triton_cache_hash": "EGDJYO36DUYGK3UQBUH6S7RMVKF77GGHWVMFFZR5R4TDMIZ4YVJA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/zc/czc4uswzazabvj7ebt72gzrcg2fgrugi6d7lol4a4jino45fz2ua.py b/SpecForge-ext/cache/compiled_kernels/zc/czc4uswzazabvj7ebt72gzrcg2fgrugi6d7lol4a4jino45fz2ua.py new file mode 100644 index 0000000000000000000000000000000000000000..b972ee7182a51225177daa37563af7eda85e3f4d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zc/czc4uswzazabvj7ebt72gzrcg2fgrugi6d7lol4a4jino45fz2ua.py @@ -0,0 +1,72 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8192, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 512 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/zc/czcxaz2wd7idbagw7rl2vswtaqev2ykiiewnqw7rxlvbhdqeplzj.py b/SpecForge-ext/cache/compiled_kernels/zc/czcxaz2wd7idbagw7rl2vswtaqev2ykiiewnqw7rxlvbhdqeplzj.py new file mode 100644 index 0000000000000000000000000000000000000000..fa7f3594154d2875d95d4e2d279f56443f6668a6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zc/czcxaz2wd7idbagw7rl2vswtaqev2ykiiewnqw7rxlvbhdqeplzj.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/zc/f86c34752b94b735a0a16e425d3d40279cd0e9fdb4209e41f1d0242f91c34009.best_config b/SpecForge-ext/cache/compiled_kernels/zc/f86c34752b94b735a0a16e425d3d40279cd0e9fdb4209e41f1d0242f91c34009.best_config new file mode 100644 index 0000000000000000000000000000000000000000..cbf4eb5ae8826a07243c88f3ee991df371ea45fb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zc/f86c34752b94b735a0a16e425d3d40279cd0e9fdb4209e41f1d0242f91c34009.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 53, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/zs/czs5ajn54p3jyxmsxcfenxtcm3rwqng63ls3udjpktpl3vy352ky.py b/SpecForge-ext/cache/compiled_kernels/zs/czs5ajn54p3jyxmsxcfenxtcm3rwqng63ls3udjpktpl3vy352ky.py new file mode 100644 index 0000000000000000000000000000000000000000..5ed5569c7a9d7993b37cd1fe3ebdfd403f80b0cb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zs/czs5ajn54p3jyxmsxcfenxtcm3rwqng63ls3udjpktpl3vy352ky.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/zs/czsnkxezqf7g2xr3cg2xrd2oooeztnp7xu42mufa3alvnuph7gsy.py b/SpecForge-ext/cache/compiled_kernels/zs/czsnkxezqf7g2xr3cg2xrd2oooeztnp7xu42mufa3alvnuph7gsy.py new file mode 100644 index 0000000000000000000000000000000000000000..7098eb27a428c2e393e511e7d302573b2fa8c607 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zs/czsnkxezqf7g2xr3cg2xrd2oooeztnp7xu42mufa3alvnuph7gsy.py @@ -0,0 +1,527 @@ +# AOT ID: ['8_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/us/cuss5perekriv5zlubnk52f3pej5qclzrmta7cvidhkpzhupmvt5.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bi/cbigeynkmamirzra5ocdek4vfe3idnh2kr2bfscbxtiim3rq5df5.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_1, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_12 +# full_blocks => eq_24 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_12, view_7 +# suffix_mask => ge_2 +# Graph fragment: +# %arg1_1 : Tensor "i64[2][1]cuda:6" = PlaceHolder[target=arg1_1] +# %sum_1 : Tensor "i64[2, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 32*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:6" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[2, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:6"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %ge_1 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[2][1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %index : Tensor "i64[2][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[2, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {}) +# %lt : Tensor "b8[2, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, %arg0_1]), kwargs = {}) +# %index_1 : Tensor "i64[2][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[2, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[2, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[2, 2048, 1][2048, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_2 : Tensor "b8[s37][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[2][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[2, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, %arg0_1]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub_12 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_12, 2048), kwargs = {}) +# %eq_12 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_12), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[2, 1, 2048, s37][2048*Max(1, s37), 2048*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, 2048, %arg0_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[2, 1, 2048, 128*(((s37 + 127)//128))][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_23, 0, 0], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[2, 1, 16, 128, ((s37 + 127)//128), 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [2, 1, 16, 128, %floordiv_1, 128]), kwargs = {}) +# %permute : Tensor "b8[2, 1, 16, ((s37 + 127)//128), 128, 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_24 : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_24, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1024, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/pz/cpz4qtb54zohukrmcagdai5zuh3utgu3eax4ghwvhtdsekz2cclm.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# num_blocks_in_row => sum_2 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 32*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:6" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:6" = PlaceHolder[target=sum_2] +# %getitem_1 : Tensor "i64[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 32*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:6" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %iota_7 : Tensor "i64[2][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[2, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_1,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False}) +# %sum_2 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[2, 1, 16, 1][16, 16, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_1,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6}) +# %where : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %index_put : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %sum_2,%convert_element_type_3,%convert_element_type_4,%buf13 +triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2 = async_compile.triton('triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/we/cwe54qdzud6xskhvjsubzqbviobtofi7rtla3cz3fvufmekfy4qf.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf13 : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:6" = PlaceHolder[target=buf13] +# %buf15 : Tensor "i16[2, 1, ((s37 + 127)//128), 16][16*(((s37 + 127)//128)), 32*(((s37 + 127)//128)), 16, 1]cuda:6" = PlaceHolder[target=buf15] +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][((s37 + 127)//128), 2*(((s37 + 127)//128)), 1]cuda:6" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_1), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[2, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[2, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[2, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf15,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 64, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1 = args + args.clear() + s37 = arg0_1 + assert_size_stride(arg1_1, (2, ), (1, )) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf12 = empty_strided_cuda((2, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 32 + 32*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream6) + buf19 = empty_strided_cuda((2, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 32 + 32*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_poi_fused_new_zeros_0.run(buf19, triton_poi_fused_new_zeros_0_xnumel, stream=stream6) + ps0 = (127 + s37) // 128 + ps1 = 16*((127 + s37) // 128) + buf1 = empty_strided_cuda((2, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 32*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((2, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 32*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 32*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg1_1, buf1, buf5, ps0, s37, ps1, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream6) + del arg1_1 + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + buf4 = buf2[1] + assert_size_stride(buf4, (2, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 32*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf10 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + buf11 = empty_strided_cuda((2, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf1, buf4, buf10, buf11, buf12, ps0, s37, 32, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream6) + del buf1 + del buf4 + buf26 = empty_strided_cuda((2, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32) + buf28 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 2*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf12, buf26, buf28, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream6) + del buf12 + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + buf8 = buf6[1] + assert_size_stride(buf8, (2, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 32*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf17 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + buf18 = empty_strided_cuda((2, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf5, buf8, buf17, buf18, buf19, ps0, s37, 32, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream6) + del buf5 + del buf8 + buf23 = empty_strided_cuda((2, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32) + buf25 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 2*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf19, buf23, buf25, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream6) + del buf19 + return (buf23, buf25, buf26, buf28, buf18, buf17, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 4096 + arg1_1 = rand_strided((2, ), (1, ), device='cuda:6', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/zu/3354266b18c8d5c19fdbacd1e43605895ef40e6d2fc5778a4d2c00cb64dc377c.best_config b/SpecForge-ext/cache/compiled_kernels/zu/3354266b18c8d5c19fdbacd1e43605895ef40e6d2fc5778a4d2c00cb64dc377c.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b52ad636ead7f0881602759bddab2da5b43669a1 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zu/3354266b18c8d5c19fdbacd1e43605895ef40e6d2fc5778a4d2c00cb64dc377c.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 55, "triton_cache_hash": "XRR2QXTZQK4DSBTDJUTNXO6FEFXI2IIRKSC5GYSBWLTL56SKI4WA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/zu/czu2jyesrdsgfrod6l7j2iof2pn657e57odk5qfyk2zi2uaqndjj.py b/SpecForge-ext/cache/compiled_kernels/zu/czu2jyesrdsgfrod6l7j2iof2pn657e57odk5qfyk2zi2uaqndjj.py new file mode 100644 index 0000000000000000000000000000000000000000..bb16de8fa20d0fd990250119d1f795548ce8f283 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zu/czu2jyesrdsgfrod6l7j2iof2pn657e57odk5qfyk2zi2uaqndjj.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/zu/czupkfdsgvzkkkyrmre5slwdxod32ccb5eacvhg2ud5wd2ypvoq2.py b/SpecForge-ext/cache/compiled_kernels/zu/czupkfdsgvzkkkyrmre5slwdxod32ccb5eacvhg2ud5wd2ypvoq2.py new file mode 100644 index 0000000000000000000000000000000000000000..65ee824b9a81f34af6133d3ed2b8e9f6328770fc --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zu/czupkfdsgvzkkkyrmre5slwdxod32ccb5eacvhg2ud5wd2ypvoq2.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/zu/e35c8f053fea1b3ce28a483bb59bad6cc21697a20f4bee5b6e815fa402c804c4.best_config b/SpecForge-ext/cache/compiled_kernels/zu/e35c8f053fea1b3ce28a483bb59bad6cc21697a20f4bee5b6e815fa402c804c4.best_config new file mode 100644 index 0000000000000000000000000000000000000000..cbf4eb5ae8826a07243c88f3ee991df371ea45fb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zu/e35c8f053fea1b3ce28a483bb59bad6cc21697a20f4bee5b6e815fa402c804c4.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 53, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"} \ No newline at end of file