Lekr0 commited on
Commit
80d5b8b
·
verified ·
1 Parent(s): eae7bce

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. SpecForge-ext/cache/compiled_kernels/2j/c2jyvidugg4t2zvjimwjrb4yacpc5zz5qifflapqv3x2b34cxuq7.py +56 -0
  2. SpecForge-ext/cache/compiled_kernels/2w/c2w4lzmy7ekz2bm6dysknbimkeh24xjafkehykmvy5slelzkuz4g.py +161 -0
  3. SpecForge-ext/cache/compiled_kernels/37/c37hhpqpo6nmqnyehjfjbhe3e5gtvdqkytyzohipcst7jy7ol4iy.py +675 -0
  4. SpecForge-ext/cache/compiled_kernels/3t/c3tsudqpzzym4mczuvujkiocbjfkpeu64fxstxm4zhnfoe575tz5.py +52 -0
  5. SpecForge-ext/cache/compiled_kernels/4a/c4al3iqkp6tuwbxkuvfdhsroylad5b7vhjzbcuo4bmvqojqx55a6.py +37 -0
  6. SpecForge-ext/cache/compiled_kernels/4f/454d8d353d28ad90c99c8953cfbd86dfbda71629c2e83398709dc784450ea2cc.best_config +1 -0
  7. SpecForge-ext/cache/compiled_kernels/4f/c4ft2b47ctfnp5zp5apvq5kvdlqubdrkzxpqndsh5oasyfr4v7y7.py +50 -0
  8. SpecForge-ext/cache/compiled_kernels/4f/c4ftkcyg442lwmtmm6lclyxflgi5xjez7jaopr447jjiva2hmpax.py +161 -0
  9. SpecForge-ext/cache/compiled_kernels/4f/c4fwwpijdyl5egtippb7rggm43z2kiggh4onk7xkd7o5v7vfl3c7.py +1051 -0
  10. SpecForge-ext/cache/compiled_kernels/4i/9b9fb3b21587241e4ad8c181607f493e81c755cfbd40bac95f98eae271b2754d.best_config +1 -0
  11. SpecForge-ext/cache/compiled_kernels/4i/c4iwnhsf5kfmm7jnzrkyiv4x3yahjog6dyhf4prm2cjdi5xhllx2.py +63 -0
  12. SpecForge-ext/cache/compiled_kernels/4l/c4lbz3jtnjjxbp7lftpjy4iam6ao6fc5cpp42bxihe27bm4qlhss.py +44 -0
  13. SpecForge-ext/cache/compiled_kernels/4n/a4add0613c3c13d6644e27d4d0641afe951924b14998f7667d2b2ebdefe532f7.best_config +1 -0
  14. SpecForge-ext/cache/compiled_kernels/4n/c4ntlraqki6522y3kmq7crnap6gq5asdu5huu7r2d7hvfkgash6w.py +25 -0
  15. SpecForge-ext/cache/compiled_kernels/4v/c4v5ovh2xgazpxywsn665wlhmrlaz6snvnzzmii7gxagr7rjrhrr.py +552 -0
  16. SpecForge-ext/cache/compiled_kernels/4y/c4yua3qi2b3xk6rn6ls5sdrsrpavp4zes7z62ki32y5ijfhzw4bb.py +552 -0
  17. SpecForge-ext/cache/compiled_kernels/6f/ba9cb84a5b5ef82fddf7d6be536aa0e0768988ffdd80996052da5fb28f5bfff3.best_config +1 -0
  18. SpecForge-ext/cache/compiled_kernels/6n/c6njycmp52a4ww57u7ir3n6hwhaktjczce3zzyrhirlmhjbkrrhg.py +693 -0
  19. SpecForge-ext/cache/compiled_kernels/7k/c7kogmtwjpemxq6qqxi6bohljmze6cjf34eo47hpufuxmpjep3yw.py +320 -0
  20. SpecForge-ext/cache/compiled_kernels/7p/c7ph4dk7ghsg37h7a46klnkhb6rck4rpgxyqg7fjyewxnxqk5vvs.py +46 -0
  21. SpecForge-ext/cache/compiled_kernels/ag/caglk6whzazaqxxtfwcwjz3xhkspqbhu4cpbiwsvmmwxpmmmtst6.py +161 -0
  22. SpecForge-ext/cache/compiled_kernels/ao/caoqvgzvbk7exhnvkuijsznlx2ebywfk6vitynyaomz5hgx5szk5.py +62 -0
  23. SpecForge-ext/cache/compiled_kernels/aw/cawxo2ohlu2xus3es5wun6g3qdjlbckp23dho2fo6p76pf7ogcso.py +322 -0
  24. SpecForge-ext/cache/compiled_kernels/c4/3fc868fcdc136a60cbcdc167284005fb6cd4078af5cf939debad2799d55dedad.best_config +1 -0
  25. SpecForge-ext/cache/compiled_kernels/c4/cc44tmaxtaxohkbf52w5omwmrxhrmn6iuplipagv7rlnxaz6dkey.py +552 -0
  26. SpecForge-ext/cache/compiled_kernels/c4/cc4r2l3x4dfli5iih5dji2abfxoclfozqdaqfbdxtcf6lqfpqwdo.py +49 -0
  27. SpecForge-ext/cache/compiled_kernels/cm/ccmqky4m65yifqjmfuu7vgvpuhwpa4ybaxffiy3mu2e6yzgecghe.py +25 -0
  28. SpecForge-ext/cache/compiled_kernels/dd/cddrh2oo46t7tins6cvtu23g2titlwclg4aile7eli326p7we42m.py +161 -0
  29. SpecForge-ext/cache/compiled_kernels/dl/b1f7dcc79c7c02fa44a9647ad7a02640f8312b36f97c27e92cc10dbab8e47d63.best_config +1 -0
  30. SpecForge-ext/cache/compiled_kernels/dl/cdlmoxz5rmtmnvhkkdtgykahwdzntxp2vrhxdea2s6finrwqdeut.py +86 -0
  31. SpecForge-ext/cache/compiled_kernels/do/cdoarqsgem4ej5qjlp6zd22rf6fimpoonczzpmfv63um26txbfab.py +168 -0
  32. SpecForge-ext/cache/compiled_kernels/dq/bbb4d7862e75b16b3f47ca1a7d19d9cb4b2d5337c27f7396cb01891263c9b13a.best_config +1 -0
  33. SpecForge-ext/cache/compiled_kernels/dq/cdq6jyounnaz2w4x6s5oljefpge3fzx66pi3x25iwcuc6vazkfx6.py +49 -0
  34. SpecForge-ext/cache/compiled_kernels/dq/cdqxxevdyssoyut2euw55y27cahqqcgmvyuhdihb4tmner7cfc7f.py +49 -0
  35. SpecForge-ext/cache/compiled_kernels/dq/e6aa9461d93df8973681493d15479cff1a0d8302c7a7de253f84ade82cf09c3e.best_config +1 -0
  36. SpecForge-ext/cache/compiled_kernels/dt/cdthlbsdpcqgxus7ldvwk23vvgojrmkgt7yidbhj27c2esjsap6w.py +164 -0
  37. SpecForge-ext/cache/compiled_kernels/dt/cdtjh6gxoepiahz2caz7vmm66wc5rf2ib5iyvtxe3w7pr44tvvpt.py +1051 -0
  38. SpecForge-ext/cache/compiled_kernels/dw/cdwf7pztwx35f2ysnyf6io3giyljdt7efoxairyx6so6kpwdnnl2.py +835 -0
  39. SpecForge-ext/cache/compiled_kernels/dw/cdwxivilyaij5fi345sh6qe7kemmtker7fznljyr22uuhwbwlgsx.py +675 -0
  40. SpecForge-ext/cache/compiled_kernels/e6/ce6g3e5xikzaf3a5wmxill5os7magq3p3hzz7uw37za4jjui6tk6.py +552 -0
  41. SpecForge-ext/cache/compiled_kernels/e6/ce6sgne5yx3pyeim455xwwbqvpu2da3rro3rzyopm3res7mhkspf.py +835 -0
  42. SpecForge-ext/cache/compiled_kernels/f6/cf6ayxqoma6zlumium5vkfjxneuep3h7lxmtssd73sg7bynrgpyn.py +552 -0
  43. SpecForge-ext/cache/compiled_kernels/fh/cfhmsnuqfbjggcp2r4forretj7wzvobbq6w5hy337y6tmciawqkk.py +62 -0
  44. SpecForge-ext/cache/compiled_kernels/fh/ebe6017c015020b128565a146c63c01eb1d20ffe6e82484e1c26bb63be24756a.best_config +1 -0
  45. SpecForge-ext/cache/compiled_kernels/fl/cfl7aqky4mcwhud5rcyx5e6sredhx2vbbrykfa5v67vwkgveygd5.py +159 -0
  46. SpecForge-ext/cache/compiled_kernels/gn/cgnmjxikvi5ulcyj3uozif3le5hd26kw2kjhkcbhupqgudqi3bwn.py +72 -0
  47. SpecForge-ext/cache/compiled_kernels/gn/cgnsrigp6qu2lbqq76g27kshvt2bzkyjnupza5ds7znhjxrnwhif.py +49 -0
  48. SpecForge-ext/cache/compiled_kernels/gv/cgva67py5joafltlxqsoz5uf2a7qh2rakl35e3wsc4nbdlv75anq.py +835 -0
  49. SpecForge-ext/cache/compiled_kernels/gv/cgvbha5mvyldninvrzu5qgbcoz6irvhuphtcgrde6mr733uggxnb.py +543 -0
  50. SpecForge-ext/cache/compiled_kernels/gy/94796f3e1399aa6e798adba6b896031b3152400abd45f5ee80e2ec3df79f0b97.best_config +1 -0
SpecForge-ext/cache/compiled_kernels/2j/c2jyvidugg4t2zvjimwjrb4yacpc5zz5qifflapqv3x2b34cxuq7.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 67108864},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x4 = xindex
23
+ x2 = ((xindex // ks0) % ks1)
24
+ x0 = (xindex % ks3)
25
+ x5 = xindex // ks3
26
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
27
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
28
+ tmp2 = ks2
29
+ tmp3 = tmp1 + tmp2
30
+ tmp4 = tmp1 < 0
31
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
32
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
33
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
34
+ tmp8 = tmp0 * tmp7
35
+ tmp9 = x0
36
+ tmp10 = tl.full([1], 0, tl.int64)
37
+ tmp11 = tmp9 >= tmp10
38
+ tmp12 = ks3 + (-1)*(ks3 // 2)
39
+ tmp13 = tmp9 < tmp12
40
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp15 = -tmp14
42
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
43
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
44
+ tmp18 = tmp9 >= tmp12
45
+ tmp19 = ks3
46
+ tmp20 = tmp9 < tmp19
47
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
48
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
49
+ tmp23 = ks4
50
+ tmp24 = tmp1 + tmp23
51
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
52
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
53
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
54
+ tmp28 = tmp22 * tmp27
55
+ tmp29 = tmp8 + tmp28
56
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
SpecForge-ext/cache/compiled_kernels/2w/c2w4lzmy7ekz2bm6dysknbimkeh24xjafkehykmvy5slelzkuz4g.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['11_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ce/cceutci466trbhyuepvkfxihcvlq4wgwo5on5qew43oksrg2qng2.py
38
+ # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
39
+ # Source node to ATen node mapping:
40
+ # target_head => convert_element_type
41
+ # target_p => div
42
+ # Graph fragment:
43
+ # %arg1_1 : Tensor "bf16[2, s67, 32000][32000*s67, 32000, 1]cuda:2" = PlaceHolder[target=arg1_1]
44
+ # %getitem : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:2" = PlaceHolder[target=getitem]
45
+ # %getitem_1 : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:2" = PlaceHolder[target=getitem_1]
46
+ # %convert_element_type : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg1_1, torch.float32), kwargs = {})
47
+ # %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {})
48
+ # %sub_tensor : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {})
49
+ # %exp_default : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {})
50
+ # %div : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {})
51
+ # return %getitem,%getitem_1,%div
52
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', '''
53
+ import triton
54
+ import triton.language as tl
55
+
56
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
57
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
58
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
59
+ triton_helpers.set_driver_to_gpu()
60
+
61
+ @triton_heuristics.reduction(
62
+ size_hints={'x': 4096, 'r0_': 32768},
63
+ reduction_hint=ReductionHint.INNER,
64
+ filename=__file__,
65
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
67
+ )
68
+ @triton.jit
69
+ def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
70
+ r0_numel = 32000
71
+ rnumel = r0_numel
72
+ RBLOCK: tl.constexpr = R0_BLOCK
73
+ xoffset = tl.program_id(0) * XBLOCK
74
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
75
+ xmask = xindex < xnumel
76
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
77
+ rbase = r0_base
78
+ x0 = xindex
79
+ _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
80
+ _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
81
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
82
+ r0_index = r0_offset + r0_base
83
+ r0_mask = r0_index < r0_numel
84
+ roffset = r0_offset
85
+ rindex = r0_index
86
+ r0_1 = r0_index
87
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
88
+ tmp1 = tmp0.to(tl.float32)
89
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
90
+
91
+ _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
92
+ _tmp3_max, _tmp3_sum, tmp2, False
93
+ )
94
+
95
+ _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max)
96
+ _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum)
97
+
98
+ tmp3, tmp4 = triton_helpers.online_softmax_reduce(
99
+ _tmp3_max, _tmp3_sum, 1, False)
100
+ tmp3 = tmp3[:, None]
101
+ tmp4 = tmp4[:, None]
102
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
103
+ r0_index = r0_offset + r0_base
104
+ r0_mask = r0_index < r0_numel
105
+ roffset = r0_offset
106
+ rindex = r0_index
107
+ r0_1 = r0_index
108
+ tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
109
+ tmp6 = tmp5.to(tl.float32)
110
+ tmp7 = tmp6 - tmp3
111
+ tmp8 = libdevice.exp(tmp7)
112
+ tmp9 = (tmp8 / tmp4)
113
+ tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask)
114
+ ''', device_str='cuda')
115
+
116
+
117
+ async_compile.wait(globals())
118
+ del async_compile
119
+
120
+ class Runner:
121
+ def __init__(self, partitions):
122
+ self.partitions = partitions
123
+
124
+ def recursively_apply_fns(self, fns):
125
+ new_callables = []
126
+ for fn, c in zip(fns, self.partitions):
127
+ new_callables.append(fn(c))
128
+ self.partitions = new_callables
129
+
130
+ def call(self, args):
131
+ arg0_1, arg1_1 = args
132
+ args.clear()
133
+ s67 = arg0_1
134
+ assert_size_stride(arg1_1, (2, s67, 32000), (32000*s67, 32000, 1))
135
+ with torch.cuda._DeviceGuard(2):
136
+ torch.cuda.set_device(2)
137
+ buf2 = empty_strided_cuda((2, s67, 32000), (32000*s67, 32000, 1), torch.float32)
138
+ # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
139
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel = 2*s67
140
+ stream2 = get_raw_stream(2)
141
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg1_1, buf2, triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel, 32000, stream=stream2)
142
+ del arg1_1
143
+ return (buf2, )
144
+
145
+ runner = Runner(partitions=[])
146
+ call = runner.call
147
+ recursively_apply_fns = runner.recursively_apply_fns
148
+
149
+
150
+ def benchmark_compiled_module(times=10, repeat=10):
151
+ from torch._dynamo.testing import rand_strided
152
+ from torch._inductor.utils import print_performance
153
+ arg0_1 = 1856
154
+ arg1_1 = rand_strided((2, 1856, 32000), (59392000, 32000, 1), device='cuda:2', dtype=torch.bfloat16)
155
+ fn = lambda: call([arg0_1, arg1_1])
156
+ return print_performance(fn, times=times, repeat=repeat)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ from torch._inductor.wrapper_benchmark import compiled_module_main
161
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/37/c37hhpqpo6nmqnyehjfjbhe3e5gtvdqkytyzohipcst7jy7ol4iy.py ADDED
@@ -0,0 +1,675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['6_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
17
+ import triton
18
+ import triton.language as tl
19
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
20
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
21
+
22
+ aten = torch.ops.aten
23
+ inductor_ops = torch.ops.inductor
24
+ _quantized = torch.ops._quantized
25
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
26
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
27
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
28
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
29
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
30
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
31
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
32
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
33
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
34
+ async_compile = AsyncCompile()
35
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
36
+
37
+
38
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/kd/ckd3pok5sro2yqebn2h6a3e2gj73iwa2hipdtvfjxehawlkn6dqo.py
39
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
40
+ # Source node to ATen node mapping:
41
+ # flex_attention => flex_attention
42
+ # Graph fragment:
43
+ # %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_1]
44
+ # %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:2" = PlaceHolder[target=primals_2]
45
+ # %primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:2" = PlaceHolder[target=primals_3]
46
+ # %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=getitem_1]
47
+ # %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=buf1]
48
+ # %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_5]
49
+ # %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=primals_4]
50
+ # %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_7]
51
+ # %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=primals_8]
52
+ # %primals_6 : Tensor "i64[8][1]cuda:2" = PlaceHolder[target=primals_6]
53
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {})
54
+ # return %getitem
55
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
56
+ import triton
57
+ import triton.language as tl
58
+
59
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
60
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
61
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
62
+
63
+ @triton_heuristics.template(
64
+
65
+ num_stages=3,
66
+ num_warps=8,
67
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
68
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
69
+
70
+ )
71
+ @triton.jit
72
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0):
73
+ PRESCALE_QK : tl.constexpr = False
74
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
75
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
76
+ WRITE_DQ : tl.constexpr = True
77
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
78
+ OUTPUT_MAX : tl.constexpr = False
79
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
80
+ IS_DIVISIBLE : tl.constexpr = True
81
+ SM_SCALE : tl.constexpr = 0.08838834764831843
82
+ GQA_SHARED_HEADS : tl.constexpr = 4
83
+ HAS_FULL_BLOCKS : tl.constexpr = True
84
+ QK_HEAD_DIM : tl.constexpr = 128
85
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ V_HEAD_DIM : tl.constexpr = 128
87
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
88
+ SAFE_HEAD_DIM : tl.constexpr = True
89
+ USE_TMA : tl.constexpr = False
90
+ BLOCK_M : tl.constexpr = 128
91
+ BLOCK_N : tl.constexpr = 64
92
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
93
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
94
+ INDEX_DTYPE : tl.constexpr = tl.int32
95
+ Q = arg_Q
96
+ K = arg_K
97
+ V = arg_V
98
+ LSE = arg_LSE
99
+ MAX = arg_MAX
100
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
101
+ KV_IDX = arg_KV_IDX
102
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
103
+ FULL_KV_IDX = arg_FULL_KV_IDX
104
+
105
+ # Sub notation for this kernel:
106
+ #
107
+ # Q: Query, K: Key, V: Value
108
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
109
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
110
+ # V_HEAD_DIM: The dimension of the value embeddings
111
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
112
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
113
+ #
114
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
115
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
116
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
117
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
118
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
119
+ #
120
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
121
+ #
122
+ # (Modifiable) Performance tuning options
123
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
124
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
125
+
126
+ # The below are kernel options that can be applied for certain score_mods,
127
+ # or involve a numerics vs. perf tradeoff
128
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
129
+ # about 20% more numerical error, but slightly faster.
130
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
131
+ # is not masked out? If so, we can skip an extra safety check
132
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
133
+ # contiguous? If so, we don't need to do an indirect jump for every block
134
+
135
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
136
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
137
+
138
+ # Define strides of inputs
139
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
140
+ stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1
141
+ stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1
142
+
143
+ ZQ = 8
144
+ HQ = 32
145
+ Q_LEN = 2048
146
+ ZKV = 8
147
+ KV_LEN = 2048
148
+
149
+ MATMUL_PRECISION = Q.dtype.element_ty
150
+
151
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
152
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
153
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
154
+
155
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
156
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
157
+ off_zkv = off_zq % ZKV
158
+ off_hkv = off_hq // GQA_SHARED_HEADS
159
+ off_g = off_hq % GQA_SHARED_HEADS
160
+
161
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
162
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
163
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
164
+
165
+ Q = Q + q_offset
166
+ K = K + k_offset
167
+ V = V + v_offset
168
+
169
+ # Setting up the TMA descriptors for Q, K, V
170
+ desc_q = None
171
+ desc_k = None
172
+ desc_v = None
173
+
174
+ SPARSE_Z = 8
175
+ SPARSE_HQ = 1
176
+
177
+ sparse_idx_z = off_zq % SPARSE_Z
178
+ sparse_idx_hq = off_hq % SPARSE_HQ
179
+
180
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
181
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
182
+
183
+ stride_kv_num_blks_h = 16
184
+ stride_kv_idx_h = 256
185
+ stride_kv_idx_m = 16
186
+
187
+ # initialize pointer to m and l
188
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
189
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
190
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
191
+
192
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
193
+
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
196
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
197
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
198
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
199
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
200
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
201
+
202
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
203
+ # We don't know anything "special" about these blocks, so we need to apply
204
+ # both score_mod and mask_mod to it
205
+ kv_indices = KV_IDX + sparse_kv_idx_offset
206
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
207
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
208
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
209
+
210
+
211
+ # K and V pointers will be passed directly to forward_inner
212
+
213
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
214
+
215
+
216
+ acc, l_i, m_i = forward_inner(
217
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
218
+ q, K, V,
219
+ desc_k, desc_v, Q_LEN, KV_LEN,
220
+ acc, l_i, m_i,
221
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
222
+ kv_start,
223
+ kv_indices, kv_num_blocks,
224
+ 0, block_n_end,
225
+ MATMUL_PRECISION,
226
+ stride_kk, stride_kn, stride_vn, stride_vk,
227
+ IS_FULL_BLOCKS=False,
228
+ )
229
+
230
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
231
+ # We know these blocks are guaranteed to be "full", so we don't need to
232
+ # apply mask_mod to them - only score_mod
233
+ if HAS_FULL_BLOCKS:
234
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
235
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
236
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
237
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
238
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
239
+ # K and V pointers will be passed directly to forward_inner
240
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
241
+
242
+ acc, l_i, m_i = forward_inner(
243
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
244
+ q, K, V,
245
+ desc_k, desc_v, Q_LEN, KV_LEN,
246
+ acc, l_i, m_i,
247
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
248
+ kv_start,
249
+ kv_indices, kv_num_blocks,
250
+ 0, block_n_end,
251
+ MATMUL_PRECISION,
252
+ stride_kk, stride_kn, stride_vn, stride_vk,
253
+ IS_FULL_BLOCKS=True,
254
+ )
255
+
256
+
257
+ # [Note] Handle fully masked out rows:
258
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
259
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
260
+ l_i = tl.where(l_i == 0.0, 1, l_i)
261
+
262
+ acc = acc / l_i[:, None]
263
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
264
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
265
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
266
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
267
+
268
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
269
+
270
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
271
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
272
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
273
+
274
+ if OUTPUT_LOGSUMEXP:
275
+ off_hz = off_zq * HQ + off_hq
276
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
277
+ lse = m_i + tl.math.log2(l_i)
278
+ if IS_DIVISIBLE:
279
+ tl.store(l_ptrs, lse)
280
+ else:
281
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
282
+
283
+ if OUTPUT_MAX:
284
+ off_hz = off_zq * HQ + off_hq
285
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
286
+ if IS_DIVISIBLE:
287
+ tl.store(max_ptrs, m_i)
288
+ else:
289
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
290
+
291
+
292
+ # Utility triton funcs
293
+ @triton.jit
294
+ def get_offset_for_next_block(
295
+ loop_iter, col_indices, total_blocks,
296
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
297
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
298
+ ):
299
+ if BLOCKS_ARE_CONTIGUOUS:
300
+ return BLOCK
301
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
302
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
303
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
304
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
305
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
306
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
307
+ return offset
308
+
309
+ @triton.jit
310
+ def get_bounded_indices(indices, max_len=None):
311
+ return indices % max_len if max_len is not None else indices
312
+
313
+ @triton.jit
314
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
315
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr)
317
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
319
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
320
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
321
+ else:
322
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
323
+
324
+ @triton.jit
325
+ def load_checked_2d(
326
+ ptr,
327
+ offs_m,
328
+ offs_n,
329
+ stride_m,
330
+ stride_n,
331
+ IS_DIVISIBLE_M: tl.constexpr,
332
+ IS_DIVISIBLE_N: tl.constexpr,
333
+ M_LEN: tl.constexpr,
334
+ N_LEN: tl.constexpr,
335
+ ):
336
+ # Calculate final pointer if strides are provided
337
+ if stride_m is not None and stride_n is not None:
338
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
339
+
340
+ # Handle all masking cases
341
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
343
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
345
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
346
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
347
+ else: # Both divisible
348
+ return tl.load(ptr)
349
+
350
+
351
+ # Common Imports
352
+ @triton.jit
353
+ def forward_block_mn(
354
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
355
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
356
+ # accumulated values
357
+ acc, l_i, m_i,
358
+ # Offsets
359
+ off_z, off_h, offs_m, offs_n,
360
+ # Offsets needed for TMA loads
361
+ kv_start,
362
+ kv_offset,
363
+ MATMUL_PRECISION, RCP_LN2,
364
+ # Strides for K and V
365
+ stride_kk, stride_kn, stride_vn, stride_vk,
366
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
367
+
368
+ ):
369
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
370
+ PRESCALE_QK : tl.constexpr = False
371
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
372
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
373
+ WRITE_DQ : tl.constexpr = True
374
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
375
+ OUTPUT_MAX : tl.constexpr = False
376
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
377
+ IS_DIVISIBLE : tl.constexpr = True
378
+ SM_SCALE : tl.constexpr = 0.08838834764831843
379
+ GQA_SHARED_HEADS : tl.constexpr = 4
380
+ HAS_FULL_BLOCKS : tl.constexpr = True
381
+ QK_HEAD_DIM : tl.constexpr = 128
382
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ V_HEAD_DIM : tl.constexpr = 128
384
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
385
+ SAFE_HEAD_DIM : tl.constexpr = True
386
+ USE_TMA : tl.constexpr = False
387
+ BLOCK_M : tl.constexpr = 128
388
+ BLOCK_N : tl.constexpr = 64
389
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
390
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
391
+ INDEX_DTYPE : tl.constexpr = tl.int32
392
+
393
+
394
+ # -- load k --
395
+ # NB reversed order to since K is transposed
396
+ kv_base_offset = kv_start + kv_offset
397
+
398
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
399
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
400
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
401
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
402
+
403
+ k = tl.trans(k)
404
+ # -- compute qk ---
405
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
406
+ if not PRESCALE_QK:
407
+ qk *= SM_SCALE
408
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
409
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
410
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
411
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
412
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
413
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
414
+
415
+ tmp0 = (qk)
416
+ post_mod_scores = tmp0
417
+
418
+
419
+ if CHECK_BLOCK_BOUNDARY:
420
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
421
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
422
+
423
+ if not IS_FULL_BLOCKS:
424
+ tmp1 = tl.full([1], False, tl.int1)
425
+ tmp2 = (m)
426
+ tmp3 = (n)
427
+ tmp4 = tmp2 >= tmp3
428
+ tmp5 = tmp3.to(tl.int64)
429
+ tmp6 = (off_z)
430
+ tmp7 = tl.load(in_ptr9 + tmp6)
431
+ tmp8 = tmp5 < tmp7
432
+ tmp9 = tmp2.to(tl.int64)
433
+ tmp10 = tmp9 < tmp7
434
+ tmp11 = tmp8 & tmp10
435
+ tmp12 = tmp4 & tmp11
436
+ tmp13 = tmp1 | tmp12
437
+ tmp14 = tl.full([1], 2048, tl.int32)
438
+ tmp15 = tmp3 >= tmp14
439
+ tmp16 = (tmp3 % tmp14)
440
+ tmp17 = tl.full([1], 0, tl.int32)
441
+ tmp18 = tmp16 != tmp17
442
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
443
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
444
+ tmp21 = tmp19 != tmp20
445
+ tmp22 = tmp18 & tmp21
446
+ tmp23 = tmp16 + tmp14
447
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
448
+ tmp25 = tmp24.to(tl.int64)
449
+ tmp26 = tmp25 < tmp7
450
+ tmp27 = tmp15 & tmp26
451
+ tmp28 = tmp3 - tmp2
452
+ tmp29 = (tmp28 % tmp14)
453
+ tmp30 = tmp29 != tmp17
454
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
455
+ tmp32 = tmp31 != tmp20
456
+ tmp33 = tmp30 & tmp32
457
+ tmp34 = tmp29 + tmp14
458
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
459
+ tmp36 = tmp35 == tmp17
460
+ tmp37 = tmp27 & tmp36
461
+ tmp38 = tmp13 | tmp37
462
+ mask_mod_output = tmp38
463
+
464
+
465
+ if CHECK_BLOCK_BOUNDARY:
466
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
467
+ # apply mask for partially unmasked blocks
468
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
469
+
470
+ if not PRESCALE_QK:
471
+ post_mod_scores *= RCP_LN2
472
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
473
+
474
+ # -- compute scaling constant ---
475
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
476
+ if not ROWS_GUARANTEED_SAFE:
477
+ masked_out_rows = (m_ij == float("-inf"))
478
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
479
+ else:
480
+ m_ij_masked = m_ij
481
+
482
+ alpha = tl.math.exp2(m_i - m_ij_masked)
483
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
484
+
485
+ # NB: l_i update is pulled up here since it's a bit faster
486
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
487
+ # m_ij
488
+ l_i = l_i * alpha + tl.sum(p, 1)
489
+ # # -- scale and update acc --
490
+ acc = acc * alpha[:, None]
491
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
492
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
493
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
494
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
495
+
496
+ # -- update m_i
497
+ m_i = m_ij
498
+
499
+ return acc, l_i, m_i
500
+
501
+ @triton.jit
502
+ def forward_inner(
503
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
504
+ q, K, V,
505
+ desc_k, desc_v, Q_LEN, KV_LEN,
506
+ # accumulated values
507
+ acc, l_i, m_i,
508
+ # Offsets used as inputs to score_mod & mask_mod
509
+ # of size [BLOCK_M, BLOCK_N] or scalar.
510
+ off_z, off_h, offs_m, offs_n,
511
+ # Offsets needed for TMA loads
512
+ kv_start,
513
+ # blocksparse data
514
+ kv_indices, kv_num_blocks,
515
+ # start kv and end kv block
516
+ block_n_start, block_n_end,
517
+ MATMUL_PRECISION,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS,
521
+ ):
522
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
523
+ PRESCALE_QK : tl.constexpr = False
524
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
525
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
526
+ WRITE_DQ : tl.constexpr = True
527
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
528
+ OUTPUT_MAX : tl.constexpr = False
529
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
530
+ IS_DIVISIBLE : tl.constexpr = True
531
+ SM_SCALE : tl.constexpr = 0.08838834764831843
532
+ GQA_SHARED_HEADS : tl.constexpr = 4
533
+ HAS_FULL_BLOCKS : tl.constexpr = True
534
+ QK_HEAD_DIM : tl.constexpr = 128
535
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
536
+ V_HEAD_DIM : tl.constexpr = 128
537
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
538
+ SAFE_HEAD_DIM : tl.constexpr = True
539
+ USE_TMA : tl.constexpr = False
540
+ BLOCK_M : tl.constexpr = 128
541
+ BLOCK_N : tl.constexpr = 64
542
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
543
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
544
+ INDEX_DTYPE : tl.constexpr = tl.int32
545
+
546
+
547
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
548
+ RCP_LN2: tl.constexpr = 1.44269504
549
+
550
+ if PRESCALE_QK:
551
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
552
+
553
+ kv_offset = 0
554
+
555
+ # loop over k, v and update accumulator until block_n_end
556
+ for start_n in range(block_n_start, block_n_end):
557
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
558
+ if IS_DIVISIBLE:
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS,
573
+ )
574
+ else:
575
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
576
+ # it's on par or slightly faster than only applying to the last block in fwd.
577
+ # However, we choose different strategy for bwd, where we only apply mod & mask
578
+ # to the last block because it's faster a lot.
579
+ acc, l_i, m_i = forward_block_mn(
580
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
581
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
582
+ # accumulated values
583
+ acc, l_i, m_i,
584
+ # Offsets
585
+ off_z, off_h, offs_m, offs_n,
586
+ # Offsets needed for TMA loads
587
+ kv_start,
588
+ kv_offset,
589
+ MATMUL_PRECISION, RCP_LN2,
590
+ # Strides for K and V
591
+ stride_kk, stride_kn, stride_vn, stride_vk,
592
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
593
+ )
594
+
595
+
596
+
597
+ offset = get_offset_for_next_block(
598
+ start_n, kv_indices, kv_num_blocks,
599
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
600
+ )
601
+
602
+ offs_n = offs_n + offset
603
+ kv_offset += offset
604
+
605
+
606
+ return acc, l_i, m_i
607
+ ''', device_str='cuda')
608
+
609
+
610
+ async_compile.wait(globals())
611
+ del async_compile
612
+
613
+ class Runner:
614
+ def __init__(self, partitions):
615
+ self.partitions = partitions
616
+
617
+ def recursively_apply_fns(self, fns):
618
+ new_callables = []
619
+ for fn, c in zip(fns, self.partitions):
620
+ new_callables.append(fn(c))
621
+ self.partitions = new_callables
622
+
623
+ def call(self, args):
624
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12 = args
625
+ args.clear()
626
+ assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1))
627
+ assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1))
628
+ assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1))
629
+ assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1))
630
+ assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1))
631
+ assert_size_stride(primals_6, (8, ), (1, ))
632
+ assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1))
633
+ assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1))
634
+ assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1))
635
+ assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1))
636
+ assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1))
637
+ assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1))
638
+ with torch.cuda._DeviceGuard(2):
639
+ torch.cuda.set_device(2)
640
+ buf0 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32)
641
+ buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32)
642
+ buf2 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16)
643
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
644
+ stream2 = get_raw_stream(2)
645
+ triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 8, 32, stream=stream2)
646
+ del buf1
647
+ return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, )
648
+
649
+ runner = Runner(partitions=[])
650
+ call = runner.call
651
+ recursively_apply_fns = runner.recursively_apply_fns
652
+
653
+
654
+ def benchmark_compiled_module(times=10, repeat=10):
655
+ from torch._dynamo.testing import rand_strided
656
+ from torch._inductor.utils import print_performance
657
+ primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16)
658
+ primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:2', dtype=torch.bfloat16)
659
+ primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:2', dtype=torch.bfloat16)
660
+ primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32)
661
+ primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32)
662
+ primals_6 = rand_strided((8, ), (1, ), device='cuda:2', dtype=torch.int64)
663
+ primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32)
664
+ primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32)
665
+ primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32)
666
+ primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32)
667
+ primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32)
668
+ primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32)
669
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12])
670
+ return print_performance(fn, times=times, repeat=repeat)
671
+
672
+
673
+ if __name__ == "__main__":
674
+ from torch._inductor.wrapper_benchmark import compiled_module_main
675
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/3t/c3tsudqpzzym4mczuvujkiocbjfkpeu64fxstxm4zhnfoe575tz5.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 131072, 'r0_': 128},
12
+ reduction_hint=ReductionHint.OUTER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ rnumel = r0_numel
20
+ RBLOCK: tl.constexpr = R0_BLOCK
21
+ xoffset = tl.program_id(0) * XBLOCK
22
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
23
+ xmask = xindex < xnumel
24
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
25
+ rbase = r0_base
26
+ x1 = xindex // ks0
27
+ x0 = (xindex % ks0)
28
+ _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
29
+ x3 = xindex
30
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
31
+ r0_index = r0_offset + r0_base
32
+ r0_mask = r0_index < r0_numel
33
+ roffset = r0_offset
34
+ rindex = r0_index
35
+ r0_2 = r0_index
36
+ tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32)
37
+ tmp1 = ks1*ks2
38
+ tmp2 = tmp0 < tmp1
39
+ tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
40
+ tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp5 = tmp4.to(tl.float32)
42
+ tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0)
43
+ tmp7 = tmp5 * tmp6
44
+ tmp8 = tmp7.to(tl.float32)
45
+ tmp9 = tmp3 * tmp8
46
+ tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
47
+ tmp11 = tl.where(tmp2, tmp9, tmp10)
48
+ tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
49
+ tmp14 = _tmp13 + tmp12
50
+ _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13)
51
+ tmp13 = tl.sum(_tmp13, 1)[:, None]
52
+ tl.store(out_ptr0 + (x3), tmp13, xmask)
SpecForge-ext/cache/compiled_kernels/4a/c4al3iqkp6tuwbxkuvfdhsroylad5b7vhjzbcuo4bmvqojqx55a6.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 4096, 'r0_': 32},
12
+ reduction_hint=ReductionHint.OUTER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 32
20
+ R0_BLOCK: tl.constexpr = 32
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_1 = r0_index
32
+ x0 = xindex
33
+ tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0)
34
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
35
+ tmp3 = tl.where(xmask, tmp1, 0)
36
+ tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32)
37
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
SpecForge-ext/cache/compiled_kernels/4f/454d8d353d28ad90c99c8953cfbd86dfbda71629c2e83398709dc784450ea2cc.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "E2MI47QNGZ2SJDA3U3EKHN7H3EYRAANF6T7N5SFT2CZJYNBAWCNQ"}
SpecForge-ext/cache/compiled_kernels/4f/c4ft2b47ctfnp5zp5apvq5kvdlqubdrkzxpqndsh5oasyfr4v7y7.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 128, 'r0_': 16},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ xnumel = 128
20
+ r0_numel = 16
21
+ R0_BLOCK: tl.constexpr = 16
22
+ rnumel = r0_numel
23
+ RBLOCK: tl.constexpr = R0_BLOCK
24
+ xoffset = tl.program_id(0) * XBLOCK
25
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
26
+ xmask = xindex < xnumel
27
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
28
+ r0_offset = 0
29
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
30
+ roffset = r0_offset
31
+ rindex = r0_index
32
+ r0_2 = r0_index
33
+ x0 = (xindex % 16)
34
+ x1 = xindex // 16
35
+ x3 = xindex
36
+ tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0)
37
+ tmp1 = r0_2
38
+ tmp2 = tmp1.to(tl.int16)
39
+ tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
40
+ tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
41
+ tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True)
42
+ tmp7 = tmp0.to(tl.int64)
43
+ tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
44
+ tmp10 = tl.where(xmask, tmp8, 0)
45
+ tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64)
46
+ tmp12 = tmp6.to(tl.int64)
47
+ tmp13 = tmp12.to(tl.int32)
48
+ tmp14 = tmp11.to(tl.int32)
49
+ tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask)
50
+ tl.store(out_ptr3 + (x3), tmp14, xmask)
SpecForge-ext/cache/compiled_kernels/4f/c4ftkcyg442lwmtmm6lclyxflgi5xjez7jaopr447jjiva2hmpax.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['11_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2s/c2sasa5yimiwlxmywmcvgtuh2fvol2mvhppzairkbqvuwicnbd5y.py
38
+ # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
39
+ # Source node to ATen node mapping:
40
+ # target_head => convert_element_type
41
+ # target_p => div
42
+ # Graph fragment:
43
+ # %arg1_1 : Tensor "bf16[2, s67, 32000][32000*s67, 32000, 1]cuda:7" = PlaceHolder[target=arg1_1]
44
+ # %getitem : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:7" = PlaceHolder[target=getitem]
45
+ # %getitem_1 : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:7" = PlaceHolder[target=getitem_1]
46
+ # %convert_element_type : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg1_1, torch.float32), kwargs = {})
47
+ # %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {})
48
+ # %sub_tensor : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {})
49
+ # %exp_default : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {})
50
+ # %div : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {})
51
+ # return %getitem,%getitem_1,%div
52
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', '''
53
+ import triton
54
+ import triton.language as tl
55
+
56
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
57
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
58
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
59
+ triton_helpers.set_driver_to_gpu()
60
+
61
+ @triton_heuristics.reduction(
62
+ size_hints={'x': 4096, 'r0_': 32768},
63
+ reduction_hint=ReductionHint.INNER,
64
+ filename=__file__,
65
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
67
+ )
68
+ @triton.jit
69
+ def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
70
+ r0_numel = 32000
71
+ rnumel = r0_numel
72
+ RBLOCK: tl.constexpr = R0_BLOCK
73
+ xoffset = tl.program_id(0) * XBLOCK
74
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
75
+ xmask = xindex < xnumel
76
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
77
+ rbase = r0_base
78
+ x0 = xindex
79
+ _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
80
+ _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
81
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
82
+ r0_index = r0_offset + r0_base
83
+ r0_mask = r0_index < r0_numel
84
+ roffset = r0_offset
85
+ rindex = r0_index
86
+ r0_1 = r0_index
87
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
88
+ tmp1 = tmp0.to(tl.float32)
89
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
90
+
91
+ _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
92
+ _tmp3_max, _tmp3_sum, tmp2, False
93
+ )
94
+
95
+ _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max)
96
+ _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum)
97
+
98
+ tmp3, tmp4 = triton_helpers.online_softmax_reduce(
99
+ _tmp3_max, _tmp3_sum, 1, False)
100
+ tmp3 = tmp3[:, None]
101
+ tmp4 = tmp4[:, None]
102
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
103
+ r0_index = r0_offset + r0_base
104
+ r0_mask = r0_index < r0_numel
105
+ roffset = r0_offset
106
+ rindex = r0_index
107
+ r0_1 = r0_index
108
+ tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
109
+ tmp6 = tmp5.to(tl.float32)
110
+ tmp7 = tmp6 - tmp3
111
+ tmp8 = libdevice.exp(tmp7)
112
+ tmp9 = (tmp8 / tmp4)
113
+ tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask)
114
+ ''', device_str='cuda')
115
+
116
+
117
+ async_compile.wait(globals())
118
+ del async_compile
119
+
120
+ class Runner:
121
+ def __init__(self, partitions):
122
+ self.partitions = partitions
123
+
124
+ def recursively_apply_fns(self, fns):
125
+ new_callables = []
126
+ for fn, c in zip(fns, self.partitions):
127
+ new_callables.append(fn(c))
128
+ self.partitions = new_callables
129
+
130
+ def call(self, args):
131
+ arg0_1, arg1_1 = args
132
+ args.clear()
133
+ s67 = arg0_1
134
+ assert_size_stride(arg1_1, (2, s67, 32000), (32000*s67, 32000, 1))
135
+ with torch.cuda._DeviceGuard(7):
136
+ torch.cuda.set_device(7)
137
+ buf2 = empty_strided_cuda((2, s67, 32000), (32000*s67, 32000, 1), torch.float32)
138
+ # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
139
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel = 2*s67
140
+ stream7 = get_raw_stream(7)
141
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg1_1, buf2, triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel, 32000, stream=stream7)
142
+ del arg1_1
143
+ return (buf2, )
144
+
145
+ runner = Runner(partitions=[])
146
+ call = runner.call
147
+ recursively_apply_fns = runner.recursively_apply_fns
148
+
149
+
150
+ def benchmark_compiled_module(times=10, repeat=10):
151
+ from torch._dynamo.testing import rand_strided
152
+ from torch._inductor.utils import print_performance
153
+ arg0_1 = 1904
154
+ arg1_1 = rand_strided((2, 1904, 32000), (60928000, 32000, 1), device='cuda:7', dtype=torch.bfloat16)
155
+ fn = lambda: call([arg0_1, arg1_1])
156
+ return print_performance(fn, times=times, repeat=repeat)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ from torch._inductor.wrapper_benchmark import compiled_module_main
161
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/4f/c4fwwpijdyl5egtippb7rggm43z2kiggh4onk7xkd7o5v7vfl3c7.py ADDED
@@ -0,0 +1,1051 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['6_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/th/cthv5zc2es46ngo2febwflavdqzw5qdaig35rrejlvqiistqzhbc.py
38
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
39
+ # Source node to ATen node mapping:
40
+ # Graph fragment:
41
+ # %getitem : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem]
42
+ # %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:4" = PlaceHolder[target=tangents_1]
43
+ # %buf0 : Tensor "bf16[8, 32, 2048][65536, 2048, 1]cuda:4" = PlaceHolder[target=buf0]
44
+ # %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:4, pin_memory: False})
45
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {})
46
+ # return %buf0,%buf1
47
+ triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', '''
48
+ import triton
49
+ import triton.language as tl
50
+
51
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
52
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
53
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
54
+ triton_helpers.set_driver_to_gpu()
55
+
56
+ @triton_heuristics.reduction(
57
+ size_hints={'x': 524288, 'r0_': 128},
58
+ reduction_hint=ReductionHint.DEFAULT,
59
+ filename=__file__,
60
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
61
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}}
62
+ )
63
+ @triton.jit
64
+ def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
65
+ xnumel = 524288
66
+ r0_numel = 128
67
+ rnumel = r0_numel
68
+ RBLOCK: tl.constexpr = R0_BLOCK
69
+ xoffset = tl.program_id(0) * XBLOCK
70
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
71
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
72
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
73
+ rbase = r0_base
74
+ x0 = (xindex % 2048)
75
+ x1 = ((xindex // 2048) % 32)
76
+ x2 = xindex // 65536
77
+ x4 = xindex
78
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
79
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
80
+ r0_index = r0_offset + r0_base
81
+ r0_mask = r0_index < r0_numel
82
+ roffset = r0_offset
83
+ rindex = r0_index
84
+ r0_3 = r0_index
85
+ tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
86
+ tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
87
+ tmp2 = tmp0 * tmp1
88
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
89
+ tmp5 = _tmp4 + tmp3
90
+ _tmp4 = tl.where(r0_mask, tmp5, _tmp4)
91
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
92
+ tmp6 = tmp4.to(tl.float32)
93
+ tmp7 = 0.0
94
+ tmp8 = tmp6 - tmp7
95
+ tl.store(out_ptr1 + (x4), tmp8, None)
96
+ ''', device_str='cuda')
97
+
98
+
99
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/cd/ccdjjsfw55ltptywulr7d4uka6bugxyoxsqibf4etcchr62jyb3f.py
100
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
101
+ # Source node to ATen node mapping:
102
+ # Graph fragment:
103
+ # %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_1]
104
+ # %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:4" = PlaceHolder[target=primals_2]
105
+ # %primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:4" = PlaceHolder[target=primals_3]
106
+ # %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:4" = PlaceHolder[target=getitem_1]
107
+ # %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:4" = PlaceHolder[target=buf1]
108
+ # %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:4" = PlaceHolder[target=tangents_1]
109
+ # %getitem_3 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem_3]
110
+ # %getitem_5 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:4" = PlaceHolder[target=getitem_5]
111
+ # %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_5]
112
+ # %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_4]
113
+ # %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_9]
114
+ # %primals_10 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_10]
115
+ # %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_7]
116
+ # %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_8]
117
+ # %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_11]
118
+ # %primals_12 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_12]
119
+ # %primals_6 : Tensor "i64[8][1]cuda:4" = PlaceHolder[target=primals_6]
120
+ # %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:4, pin_memory: False})
121
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {})
122
+ # return %getitem_4
123
+ triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', '''
124
+ import triton
125
+ import triton.language as tl
126
+
127
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
128
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
129
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
130
+
131
+ @triton_heuristics.template(
132
+
133
+ num_stages=3,
134
+ num_warps=8,
135
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
136
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
137
+
138
+ )
139
+ @triton.jit
140
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0):
141
+ PRESCALE_QK : tl.constexpr = False
142
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
143
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
144
+ WRITE_DQ : tl.constexpr = True
145
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
146
+ OUTPUT_MAX : tl.constexpr = False
147
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
148
+ IS_DIVISIBLE : tl.constexpr = True
149
+ SM_SCALE : tl.constexpr = 0.08838834764831843
150
+ GQA_SHARED_HEADS : tl.constexpr = 4
151
+ HAS_FULL_BLOCKS : tl.constexpr = True
152
+ QK_HEAD_DIM : tl.constexpr = 128
153
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
154
+ V_HEAD_DIM : tl.constexpr = 128
155
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
156
+ SAFE_HEAD_DIM : tl.constexpr = True
157
+ BLOCK_M1 : tl.constexpr = 64
158
+ BLOCK_N1 : tl.constexpr = 128
159
+ BLOCK_M2 : tl.constexpr = 128
160
+ BLOCK_N2 : tl.constexpr = 64
161
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
162
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
163
+ INDEX_DTYPE : tl.constexpr = tl.int32
164
+ Q = arg_Q
165
+ K = arg_K
166
+ V = arg_V
167
+ LSE = arg_LSE
168
+ DELTA = arg_DELTA
169
+ DO = arg_DO
170
+ DQ = arg_DQ
171
+ DV = arg_DV
172
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
173
+ KV_IDX = arg_KV_IDX
174
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
175
+ Q_IDX = arg_Q_IDX
176
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
177
+ FULL_KV_IDX = arg_FULL_KV_IDX
178
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
179
+ FULL_Q_IDX = arg_FULL_Q_IDX
180
+
181
+ # Sub notation for this kernel:
182
+ #
183
+ # Q: Query, K: Key, V: Value
184
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
185
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
186
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
187
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
188
+ # inductor codegen
189
+ # M: Number of queries, N: Number of keys/values
190
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
191
+ # V_HEAD_DIM: The dimension of the value embeddings
192
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
193
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
194
+ # (Modifiable) Performance tuning options
195
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
196
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
197
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
198
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
199
+ #
200
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
201
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
202
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
203
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
204
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
205
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
206
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
207
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
208
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
209
+
210
+ # The below are kernel options that can be applied for certain score_mods,
211
+ # or involve a numerics vs. perf tradeoff
212
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
213
+ # about 20% more numerical error, but slightly faster.
214
+
215
+ # Define strides of inputs
216
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
217
+ stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1
218
+ stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1
219
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
220
+
221
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
222
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1
223
+
224
+ ZQ = 8
225
+ HQ = 32
226
+ HKV = 8
227
+ Q_LEN = 2048
228
+ ZKV = 8
229
+ KV_LEN = 2048
230
+
231
+ MATMUL_PRECISION = Q.dtype.element_ty
232
+
233
+ pid = tl.program_id(0).to(INDEX_DTYPE)
234
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
235
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
236
+
237
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
238
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
239
+ off_zkv = off_zq % ZKV # kv batch idx
240
+
241
+ SPARSE_Z = 8
242
+ SPARSE_HQ = 1
243
+
244
+ sparse_idx_z = off_zq % SPARSE_Z
245
+
246
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
247
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
248
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
249
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
250
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
251
+
252
+ # offset K, V, DV pointers for batch/kv-head
253
+ K += k_adj
254
+ V += v_adj
255
+ DV += dv_adj
256
+
257
+ RCP_LN2 = 1.44269504
258
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
259
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
260
+
261
+ if pid >= NUM_KV_BLOCKS:
262
+ off_pid = pid - NUM_KV_BLOCKS
263
+ # THIS BLOCK DOES DQ
264
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
265
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
266
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
267
+ start_m2_block = off_pid % NUM_Q_BLOCKS
268
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
269
+ stride_kv_num_blks_h = 16
270
+ stride_kv_idx_h = 256
271
+ stride_kv_idx_m = 16
272
+
273
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
274
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
275
+
276
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
277
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
278
+
279
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
280
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
281
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
282
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
283
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
284
+
285
+ Q2 = Q + q_adj2
286
+ DO2 = DO + do_adj2
287
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
288
+ # if Q is broadcasted)
289
+ DQ2 = DQ + dq_adj2
290
+ LSE2 = LSE + off_chz2
291
+ DELTA2 = DELTA + off_chz2
292
+
293
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
294
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
295
+
296
+ start_m2 = start_m2_block * BLOCK_M2
297
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
298
+
299
+ # load Q and do: they stay in SRAM throughout the inner loop.
300
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
301
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
302
+
303
+ if PRESCALE_QK:
304
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
305
+
306
+ if IS_DIVISIBLE:
307
+ Di = tl.load(DELTA2 + offs_m2)
308
+ lse = tl.load(LSE2 + offs_m2)
309
+ else:
310
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
311
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
312
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
313
+ lse = lse[:, None]
314
+
315
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
316
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
317
+ kv_indices = KV_IDX + sparse_kv_idx_offset
318
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
319
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
320
+
321
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
322
+ dq = bwd_dq_inner(
323
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
324
+ K, V,
325
+ dq, q, do, Di, lse,
326
+ off_zq, off_hq2, offs_m2, offs_n2,
327
+ stride_kn, stride_kd, stride_vn, stride_vd,
328
+ kv_indices, sparse_kv_num_blocks,
329
+ MATMUL_PRECISION,
330
+ IS_FULL_BLOCKS=False,
331
+ )
332
+
333
+ if HAS_FULL_BLOCKS:
334
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
335
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
336
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
337
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
338
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
339
+
340
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
341
+ dq = bwd_dq_inner(
342
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
343
+ K, V,
344
+ dq, q, do, Di, lse,
345
+ off_zq, off_hq2, offs_m2, offs_n2,
346
+ stride_kn, stride_kd, stride_vn, stride_vd,
347
+ kv_indices, sparse_kv_num_blocks,
348
+ MATMUL_PRECISION,
349
+ IS_FULL_BLOCKS=True,
350
+ )
351
+
352
+ # Write back dQ.
353
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
354
+ dq *= SM_SCALE
355
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
356
+ tl.store(dq_ptrs, dq)
357
+ else:
358
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
359
+ else:
360
+ # THIS BLOCK DOES DK & DV
361
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
362
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
363
+
364
+ pid_mask = pid // SPARSE_KV_MULTIPLE
365
+
366
+ stride_q_num_blks_h = 16
367
+ stride_q_idx_h = 256
368
+ stride_q_idx_n = 16
369
+
370
+
371
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
372
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
373
+
374
+ start_n1 = pid * BLOCK_N1
375
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
376
+
377
+ # load K and V: they stay in SRAM throughout the inner loop.
378
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
379
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
380
+
381
+ if PRESCALE_QK:
382
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
383
+
384
+ for off_g in range(0, GQA_SHARED_HEADS):
385
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
386
+
387
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
388
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
389
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
390
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
391
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
392
+
393
+ Q1 = Q + q_adj1
394
+ DO1 = DO + do_adj1
395
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
396
+ # if Q is broadcasted)
397
+ LSE1 = LSE + off_chz1
398
+ DELTA1 = DELTA + off_chz1
399
+
400
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
401
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
402
+
403
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
404
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
405
+
406
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
407
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
408
+ q_indices = Q_IDX + sparse_q_idx_offset
409
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
410
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
411
+
412
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
413
+ dk, dv = bwd_dkdv_inner(
414
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
415
+ Q1, DO1, DELTA1, LSE1,
416
+ dk, dv, k, v,
417
+ off_zq, off_hq1, offs_n1, offs_m1,
418
+ stride_qm, stride_qd, stride_dom, stride_dod,
419
+ q_indices, sparse_q_num_blocks,
420
+ MATMUL_PRECISION,
421
+ IS_FULL_BLOCKS=False,
422
+ )
423
+
424
+
425
+ if HAS_FULL_BLOCKS:
426
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
427
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
428
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
429
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
430
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
431
+
432
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
433
+ dk, dv = bwd_dkdv_inner(
434
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
435
+ Q1, DO1, DELTA1, LSE1,
436
+ dk, dv, k, v,
437
+ off_zq, off_hq1, offs_n1, offs_m1,
438
+ stride_qm, stride_qd, stride_dom, stride_dod,
439
+ q_indices, sparse_q_num_blocks,
440
+ MATMUL_PRECISION,
441
+ IS_FULL_BLOCKS=True,
442
+ )
443
+
444
+ # Write back dV and dK.
445
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
446
+
447
+ index_n = offs_n1[:, None]
448
+ index_k = offs_k[None, :]
449
+ index_v = offs_v[None, :]
450
+
451
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
452
+ tl.store(dv_ptrs, dv)
453
+ else:
454
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
455
+
456
+ dk *= SM_SCALE
457
+
458
+ if SAFE_HEAD_DIM:
459
+ mask = index_n < KV_LEN
460
+ else:
461
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
462
+
463
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
464
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
465
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
466
+ xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq
467
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
468
+
469
+ @triton.jit
470
+ def bwd_dq_inner(
471
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
472
+ K, V, # pointers
473
+ dq, q, do, Di, lse,
474
+ off_z, off_hq, offs_m2, offs_n2,
475
+ stride_kn, stride_kd, stride_vn, stride_vd,
476
+ kv_indices, sparse_kv_num_blocks,
477
+ MATMUL_PRECISION,
478
+ IS_FULL_BLOCKS,
479
+ ):
480
+ PRESCALE_QK : tl.constexpr = False
481
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
482
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
483
+ WRITE_DQ : tl.constexpr = True
484
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
485
+ OUTPUT_MAX : tl.constexpr = False
486
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
487
+ IS_DIVISIBLE : tl.constexpr = True
488
+ SM_SCALE : tl.constexpr = 0.08838834764831843
489
+ GQA_SHARED_HEADS : tl.constexpr = 4
490
+ HAS_FULL_BLOCKS : tl.constexpr = True
491
+ QK_HEAD_DIM : tl.constexpr = 128
492
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
493
+ V_HEAD_DIM : tl.constexpr = 128
494
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
495
+ SAFE_HEAD_DIM : tl.constexpr = True
496
+ BLOCK_M1 : tl.constexpr = 64
497
+ BLOCK_N1 : tl.constexpr = 128
498
+ BLOCK_M2 : tl.constexpr = 128
499
+ BLOCK_N2 : tl.constexpr = 64
500
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
501
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
502
+ INDEX_DTYPE : tl.constexpr = tl.int32
503
+
504
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
505
+ RCP_LN2: tl.constexpr = 1.44269504
506
+ Q_LEN = 2048
507
+ KV_LEN = 2048
508
+
509
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
510
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
511
+
512
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
513
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
514
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
515
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
516
+
517
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
518
+
519
+ for start_n in range(0, hi):
520
+ dq = bwd_dq_block_mn(
521
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
522
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
523
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
524
+ stride_kn, stride_kd, stride_vn, stride_vd,
525
+ kv_indices, sparse_kv_num_blocks,
526
+ MATMUL_PRECISION, RCP_LN2,
527
+ IS_FULL_BLOCKS,
528
+ )
529
+
530
+ # Increment pointers.
531
+ offset = get_offset_for_next_block(
532
+ start_n, kv_indices, sparse_kv_num_blocks,
533
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
534
+ )
535
+
536
+ kT_ptrs += offset * stride_kn
537
+ vT_ptrs += offset * stride_vn
538
+
539
+ offs_n2 += offset
540
+
541
+ return dq
542
+
543
+
544
+ @triton.jit
545
+ def bwd_dq_block_mn(
546
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
547
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
548
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
549
+ stride_kn, stride_kd, stride_vn, stride_vd,
550
+ kv_indices, sparse_kv_num_blocks,
551
+ MATMUL_PRECISION, RCP_LN2,
552
+ IS_FULL_BLOCKS,
553
+ ):
554
+ PRESCALE_QK : tl.constexpr = False
555
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
556
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
557
+ WRITE_DQ : tl.constexpr = True
558
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
559
+ OUTPUT_MAX : tl.constexpr = False
560
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
561
+ IS_DIVISIBLE : tl.constexpr = True
562
+ SM_SCALE : tl.constexpr = 0.08838834764831843
563
+ GQA_SHARED_HEADS : tl.constexpr = 4
564
+ HAS_FULL_BLOCKS : tl.constexpr = True
565
+ QK_HEAD_DIM : tl.constexpr = 128
566
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
567
+ V_HEAD_DIM : tl.constexpr = 128
568
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
569
+ SAFE_HEAD_DIM : tl.constexpr = True
570
+ BLOCK_M1 : tl.constexpr = 64
571
+ BLOCK_N1 : tl.constexpr = 128
572
+ BLOCK_M2 : tl.constexpr = 128
573
+ BLOCK_N2 : tl.constexpr = 64
574
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
575
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
576
+ INDEX_DTYPE : tl.constexpr = tl.int32
577
+
578
+
579
+ # NB reversed order to since K is transposed
580
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
581
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
582
+ if not PRESCALE_QK:
583
+ qk *= SM_SCALE
584
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
585
+ pre_mod_scores = qk
586
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
587
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
588
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
589
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
590
+
591
+ tmp0 = (qk)
592
+ post_mod_scores = tmp0
593
+
594
+
595
+
596
+
597
+ if not IS_DIVISIBLE:
598
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
599
+
600
+ if not IS_FULL_BLOCKS:
601
+ tmp1 = tl.full([1], False, tl.int1)
602
+ tmp2 = (m)
603
+ tmp3 = (n)
604
+ tmp4 = tmp2 >= tmp3
605
+ tmp5 = tmp3.to(tl.int64)
606
+ tmp6 = (off_z)
607
+ tmp7 = tl.load(in_ptr16 + tmp6)
608
+ tmp8 = tmp5 < tmp7
609
+ tmp9 = tmp2.to(tl.int64)
610
+ tmp10 = tmp9 < tmp7
611
+ tmp11 = tmp8 & tmp10
612
+ tmp12 = tmp4 & tmp11
613
+ tmp13 = tmp1 | tmp12
614
+ tmp14 = tl.full([1], 2048, tl.int32)
615
+ tmp15 = tmp3 >= tmp14
616
+ tmp16 = (tmp3 % tmp14)
617
+ tmp17 = tl.full([1], 0, tl.int32)
618
+ tmp18 = tmp16 != tmp17
619
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
620
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
621
+ tmp21 = tmp19 != tmp20
622
+ tmp22 = tmp18 & tmp21
623
+ tmp23 = tmp16 + tmp14
624
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
625
+ tmp25 = tmp24.to(tl.int64)
626
+ tmp26 = tmp25 < tmp7
627
+ tmp27 = tmp15 & tmp26
628
+ tmp28 = tmp3 - tmp2
629
+ tmp29 = (tmp28 % tmp14)
630
+ tmp30 = tmp29 != tmp17
631
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
632
+ tmp32 = tmp31 != tmp20
633
+ tmp33 = tmp30 & tmp32
634
+ tmp34 = tmp29 + tmp14
635
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
636
+ tmp36 = tmp35 == tmp17
637
+ tmp37 = tmp27 & tmp36
638
+ tmp38 = tmp13 | tmp37
639
+ mask_mod_output = tmp38
640
+
641
+
642
+ # apply mask for partial masked block
643
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
644
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
645
+ if not PRESCALE_QK:
646
+ post_mod_scores *= RCP_LN2
647
+ p = tl.math.exp2(post_mod_scores - lse)
648
+ # Compute dP and dS.
649
+ # NB reversed order to since V is transposed
650
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
651
+
652
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
653
+ ds = p * (dp - Di[:, None])
654
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
655
+ tmp39 = (ds)
656
+ grad_scores = tmp39
657
+
658
+
659
+ if not IS_DIVISIBLE:
660
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
661
+
662
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
663
+ if WRITE_DQ:
664
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
665
+
666
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
667
+ ds = grad_scores
668
+
669
+ if not IS_FULL_BLOCKS:
670
+ # (grads) apply mask for partially unmasked block
671
+ ds = tl.where(mask_mod_output, ds, 0.0)
672
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
673
+ ds = ds.to(MATMUL_PRECISION)
674
+ # Compute dQ.
675
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
676
+
677
+ return dq
678
+
679
+
680
+ @triton.jit
681
+ def bwd_dkdv_inner(
682
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
683
+ Q, DO, DELTA, LSE, # pointers
684
+ dk, dv, k, v,
685
+ off_z, off_hq, offs_n1, offs_m1,
686
+ stride_qm, stride_qd, stride_dom, stride_dod,
687
+ q_indices, sparse_q_num_blocks,
688
+ MATMUL_PRECISION,
689
+ IS_FULL_BLOCKS,
690
+ ):
691
+ PRESCALE_QK : tl.constexpr = False
692
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
693
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
694
+ WRITE_DQ : tl.constexpr = True
695
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
696
+ OUTPUT_MAX : tl.constexpr = False
697
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
698
+ IS_DIVISIBLE : tl.constexpr = True
699
+ SM_SCALE : tl.constexpr = 0.08838834764831843
700
+ GQA_SHARED_HEADS : tl.constexpr = 4
701
+ HAS_FULL_BLOCKS : tl.constexpr = True
702
+ QK_HEAD_DIM : tl.constexpr = 128
703
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
704
+ V_HEAD_DIM : tl.constexpr = 128
705
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
706
+ SAFE_HEAD_DIM : tl.constexpr = True
707
+ BLOCK_M1 : tl.constexpr = 64
708
+ BLOCK_N1 : tl.constexpr = 128
709
+ BLOCK_M2 : tl.constexpr = 128
710
+ BLOCK_N2 : tl.constexpr = 64
711
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
712
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
713
+ INDEX_DTYPE : tl.constexpr = tl.int32
714
+
715
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
716
+ RCP_LN2: tl.constexpr = 1.44269504
717
+ Q_LEN = 2048
718
+ KV_LEN = 2048
719
+
720
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
721
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
722
+
723
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
724
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
725
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
726
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
727
+
728
+ # The minimum is needed to handle the case where we run with a super large
729
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
730
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
731
+
732
+ for start_m in range(0, hi):
733
+ dk, dv = bwd_dkdv_block_mn(
734
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
735
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
736
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
737
+ stride_qm, stride_qd, stride_dom, stride_dod,
738
+ q_indices, sparse_q_num_blocks,
739
+ MATMUL_PRECISION, RCP_LN2,
740
+ IS_FULL_BLOCKS,
741
+ )
742
+ # Increment pointers.
743
+ offset = get_offset_for_next_block(
744
+ start_m, q_indices, sparse_q_num_blocks,
745
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
746
+ )
747
+
748
+ qT_ptrs += offset * stride_qm
749
+ do_ptrs += offset * stride_dom
750
+ offs_m1 += offset
751
+
752
+ return dk, dv
753
+
754
+
755
+ @triton.jit
756
+ def bwd_dkdv_block_mn(
757
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
758
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
759
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
760
+ stride_qm, stride_qd, stride_dom, stride_dod,
761
+ q_indices, sparse_q_num_blocks,
762
+ MATMUL_PRECISION, RCP_LN2,
763
+ IS_FULL_BLOCKS,
764
+ ):
765
+ PRESCALE_QK : tl.constexpr = False
766
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
767
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
768
+ WRITE_DQ : tl.constexpr = True
769
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
770
+ OUTPUT_MAX : tl.constexpr = False
771
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
772
+ IS_DIVISIBLE : tl.constexpr = True
773
+ SM_SCALE : tl.constexpr = 0.08838834764831843
774
+ GQA_SHARED_HEADS : tl.constexpr = 4
775
+ HAS_FULL_BLOCKS : tl.constexpr = True
776
+ QK_HEAD_DIM : tl.constexpr = 128
777
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
778
+ V_HEAD_DIM : tl.constexpr = 128
779
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
780
+ SAFE_HEAD_DIM : tl.constexpr = True
781
+ BLOCK_M1 : tl.constexpr = 64
782
+ BLOCK_N1 : tl.constexpr = 128
783
+ BLOCK_M2 : tl.constexpr = 128
784
+ BLOCK_N2 : tl.constexpr = 64
785
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
786
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
787
+ INDEX_DTYPE : tl.constexpr = tl.int32
788
+
789
+
790
+ # NB reversed order since Q is transposed
791
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
792
+ # Load LSE before computing qk to reduce pipeline stall.
793
+ if IS_DIVISIBLE:
794
+ lse = tl.load(LSE + offs_m1)
795
+ else:
796
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
797
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
798
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
799
+ if not PRESCALE_QK:
800
+ qkT *= SM_SCALE
801
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
802
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
803
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
804
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
805
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
806
+
807
+ pre_mod_scores = qkT
808
+ tmp40 = (qkT)
809
+ post_mod_scores = tmp40
810
+
811
+
812
+
813
+ if not IS_DIVISIBLE:
814
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
815
+
816
+ if not IS_FULL_BLOCKS:
817
+ tmp41 = tl.full([1], False, tl.int1)
818
+ tmp42 = (m)
819
+ tmp43 = (n)
820
+ tmp44 = tmp42 >= tmp43
821
+ tmp45 = tmp43.to(tl.int64)
822
+ tmp46 = (off_z)
823
+ tmp47 = tl.load(in_ptr16 + tmp46)
824
+ tmp48 = tmp45 < tmp47
825
+ tmp49 = tmp42.to(tl.int64)
826
+ tmp50 = tmp49 < tmp47
827
+ tmp51 = tmp48 & tmp50
828
+ tmp52 = tmp44 & tmp51
829
+ tmp53 = tmp41 | tmp52
830
+ tmp54 = tl.full([1], 2048, tl.int32)
831
+ tmp55 = tmp43 >= tmp54
832
+ tmp56 = (tmp43 % tmp54)
833
+ tmp57 = tl.full([1], 0, tl.int32)
834
+ tmp58 = tmp56 != tmp57
835
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
836
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
837
+ tmp61 = tmp59 != tmp60
838
+ tmp62 = tmp58 & tmp61
839
+ tmp63 = tmp56 + tmp54
840
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
841
+ tmp65 = tmp64.to(tl.int64)
842
+ tmp66 = tmp65 < tmp47
843
+ tmp67 = tmp55 & tmp66
844
+ tmp68 = tmp43 - tmp42
845
+ tmp69 = (tmp68 % tmp54)
846
+ tmp70 = tmp69 != tmp57
847
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
848
+ tmp72 = tmp71 != tmp60
849
+ tmp73 = tmp70 & tmp72
850
+ tmp74 = tmp69 + tmp54
851
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
852
+ tmp76 = tmp75 == tmp57
853
+ tmp77 = tmp67 & tmp76
854
+ tmp78 = tmp53 | tmp77
855
+ mask_mod_output = tmp78
856
+
857
+ # (grads) apply mask for fully masked block
858
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
859
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
860
+ if not PRESCALE_QK:
861
+ post_mod_scores *= RCP_LN2
862
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
863
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
864
+ # Compute dV.
865
+ ppT = pT
866
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
867
+ if IS_DIVISIBLE:
868
+ Di = tl.load(DELTA + offs_m1)
869
+ else:
870
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
871
+ # Compute dP and dS.
872
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
873
+ dsT = pT * (dpT - Di[None, :])
874
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
875
+ tmp79 = (dsT)
876
+ grad_scores = tmp79
877
+
878
+
879
+
880
+ if not IS_DIVISIBLE:
881
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
882
+
883
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
884
+ if not WRITE_DQ:
885
+ idx_b = off_z
886
+ idx_h = off_hq
887
+ idx_m = m
888
+ idx_n = n
889
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
890
+
891
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
892
+ dsT = grad_scores
893
+ if not IS_FULL_BLOCKS:
894
+ # (grads) apply mask for partially unmasked block
895
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
896
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
897
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
898
+
899
+ return dk, dv
900
+
901
+ # Utility triton funcs
902
+ @triton.jit
903
+ def get_offset_for_next_block(
904
+ loop_iter, col_indices, total_blocks,
905
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
906
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
907
+ ):
908
+ if BLOCKS_ARE_CONTIGUOUS:
909
+ return BLOCK
910
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
911
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
912
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
913
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
914
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
915
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
916
+ return offset
917
+
918
+ @triton.jit
919
+ def get_bounded_indices(indices, max_len=None):
920
+ return indices % max_len if max_len is not None else indices
921
+
922
+ @triton.jit
923
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
924
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
925
+ return tl.load(block_ptr)
926
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
927
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
928
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
929
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
930
+ else:
931
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
932
+
933
+ @triton.jit
934
+ def load_checked_2d(
935
+ ptr,
936
+ offs_m,
937
+ offs_n,
938
+ stride_m,
939
+ stride_n,
940
+ IS_DIVISIBLE_M: tl.constexpr,
941
+ IS_DIVISIBLE_N: tl.constexpr,
942
+ M_LEN: tl.constexpr,
943
+ N_LEN: tl.constexpr,
944
+ ):
945
+ # Calculate final pointer if strides are provided
946
+ if stride_m is not None and stride_n is not None:
947
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
948
+
949
+ # Handle all masking cases
950
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
951
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
952
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
953
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
954
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
955
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
956
+ else: # Both divisible
957
+ return tl.load(ptr)
958
+ ''', device_str='cuda')
959
+
960
+
961
+ async_compile.wait(globals())
962
+ del async_compile
963
+
964
+ class Runner:
965
+ def __init__(self, partitions):
966
+ self.partitions = partitions
967
+
968
+ def recursively_apply_fns(self, fns):
969
+ new_callables = []
970
+ for fn, c in zip(fns, self.partitions):
971
+ new_callables.append(fn(c))
972
+ self.partitions = new_callables
973
+
974
+ def call(self, args):
975
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1 = args
976
+ args.clear()
977
+ assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1))
978
+ assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1))
979
+ assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1))
980
+ assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1))
981
+ assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1))
982
+ assert_size_stride(primals_6, (8, ), (1, ))
983
+ assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1))
984
+ assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1))
985
+ assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1))
986
+ assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1))
987
+ assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1))
988
+ assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1))
989
+ assert_size_stride(getitem, (8, 32, 2048, 128), (8388608, 128, 4096, 1))
990
+ assert_size_stride(getitem_1, (8, 32, 2048), (65536, 2048, 1))
991
+ assert_size_stride(tangents_1, (8, 32, 2048, 128), (8388608, 262144, 128, 1))
992
+ with torch.cuda._DeviceGuard(4):
993
+ torch.cuda.set_device(4)
994
+ buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32)
995
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
996
+ stream4 = get_raw_stream(4)
997
+ triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 524288, 128, stream=stream4)
998
+ del getitem
999
+ buf3 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16)
1000
+ buf4 = empty_strided_cuda((8, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16)
1001
+ buf5 = empty_strided_cuda((8, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16)
1002
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
1003
+ stream4 = get_raw_stream(4)
1004
+ triton_tem_fused_zeros_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_9, primals_10, primals_7, primals_8, primals_11, primals_12, primals_6, buf5, 80, 8, 8, stream=stream4)
1005
+ del buf1
1006
+ del getitem_1
1007
+ del primals_1
1008
+ del primals_10
1009
+ del primals_11
1010
+ del primals_12
1011
+ del primals_2
1012
+ del primals_3
1013
+ del primals_4
1014
+ del primals_5
1015
+ del primals_6
1016
+ del primals_7
1017
+ del primals_8
1018
+ del primals_9
1019
+ del tangents_1
1020
+ return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, None, )
1021
+
1022
+ runner = Runner(partitions=[])
1023
+ call = runner.call
1024
+ recursively_apply_fns = runner.recursively_apply_fns
1025
+
1026
+
1027
+ def benchmark_compiled_module(times=10, repeat=10):
1028
+ from torch._dynamo.testing import rand_strided
1029
+ from torch._inductor.utils import print_performance
1030
+ primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16)
1031
+ primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:4', dtype=torch.bfloat16)
1032
+ primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:4', dtype=torch.bfloat16)
1033
+ primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32)
1034
+ primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32)
1035
+ primals_6 = rand_strided((8, ), (1, ), device='cuda:4', dtype=torch.int64)
1036
+ primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32)
1037
+ primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32)
1038
+ primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32)
1039
+ primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32)
1040
+ primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32)
1041
+ primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32)
1042
+ getitem = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16)
1043
+ getitem_1 = rand_strided((8, 32, 2048), (65536, 2048, 1), device='cuda:4', dtype=torch.float32)
1044
+ tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:4', dtype=torch.bfloat16)
1045
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1])
1046
+ return print_performance(fn, times=times, repeat=repeat)
1047
+
1048
+
1049
+ if __name__ == "__main__":
1050
+ from torch._inductor.wrapper_benchmark import compiled_module_main
1051
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/4i/9b9fb3b21587241e4ad8c181607f493e81c755cfbd40bac95f98eae271b2754d.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "50b7a7455b8a2aa7fe5b57654ddf092584f02f34b265601866fdd653f06a5539", "found_by_coordesc": false, "time_taken_ms": 73, "triton_cache_hash": "GEZC7BNCXFQAGCZIOI2BQLAAUGS4IVUJ4QGCDMFUE3MMZMGBMJIQ"}
SpecForge-ext/cache/compiled_kernels/4i/c4iwnhsf5kfmm7jnzrkyiv4x3yahjog6dyhf4prm2cjdi5xhllx2.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 16384, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 16384
20
+ r0_numel = 32000
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x0 = xindex
29
+ _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
30
+ _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
31
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
32
+ r0_index = r0_offset + r0_base
33
+ r0_mask = r0_index < r0_numel
34
+ roffset = r0_offset
35
+ rindex = r0_index
36
+ r0_1 = r0_index
37
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
38
+ tmp1 = tmp0.to(tl.float32)
39
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
40
+
41
+ _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
42
+ _tmp3_max, _tmp3_sum, tmp2, False
43
+ )
44
+
45
+ _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max)
46
+ _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum)
47
+
48
+ tmp3, tmp4 = triton_helpers.online_softmax_reduce(
49
+ _tmp3_max, _tmp3_sum, 1, False)
50
+ tmp3 = tmp3[:, None]
51
+ tmp4 = tmp4[:, None]
52
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
53
+ r0_index = r0_offset + r0_base
54
+ r0_mask = r0_index < r0_numel
55
+ roffset = r0_offset
56
+ rindex = r0_index
57
+ r0_1 = r0_index
58
+ tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
59
+ tmp6 = tmp5.to(tl.float32)
60
+ tmp7 = tmp6 - tmp3
61
+ tmp8 = libdevice.exp(tmp7)
62
+ tmp9 = (tmp8 / tmp4)
63
+ tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask)
SpecForge-ext/cache/compiled_kernels/4l/c4lbz3jtnjjxbp7lftpjy4iam6ao6fc5cpp42bxihe27bm4qlhss.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 1, 'r0_': 2},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ xnumel = 1
20
+ r0_numel = 2
21
+ R0_BLOCK: tl.constexpr = 2
22
+ rnumel = r0_numel
23
+ RBLOCK: tl.constexpr = R0_BLOCK
24
+ xoffset = tl.program_id(0) * XBLOCK
25
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
26
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
27
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
28
+ r0_offset = 0
29
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
30
+ roffset = r0_offset
31
+ rindex = r0_index
32
+ r0_0 = r0_index
33
+ tmp0 = tl.load(in_ptr0 + (r0_0), None)
34
+ tmp4 = tl.load(in_ptr1 + (r0_0), None)
35
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
36
+ tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64)
37
+ tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK])
38
+ tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64)
39
+ tmp8 = tmp3.to(tl.float32)
40
+ tmp9 = tmp7.to(tl.float32)
41
+ tmp10 = 1e-06
42
+ tmp11 = triton_helpers.maximum(tmp9, tmp10)
43
+ tmp12 = (tmp8 / tmp11)
44
+ tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None)
SpecForge-ext/cache/compiled_kernels/4n/a4add0613c3c13d6644e27d4d0641afe951924b14998f7667d2b2ebdefe532f7.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q"}
SpecForge-ext/cache/compiled_kernels/4n/c4ntlraqki6522y3kmq7crnap6gq5asdu5huu7r2d7hvfkgash6w.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 1024},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4352}},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xnumel = 544
20
+ xoffset = tl.program_id(0) * XBLOCK
21
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
22
+ xmask = xindex < xnumel
23
+ x0 = xindex
24
+ tmp0 = tl.full([1], 0, tl.int32)
25
+ tl.store(out_ptr0 + (x0), tmp0, xmask)
SpecForge-ext/cache/compiled_kernels/4v/c4v5ovh2xgazpxywsn665wlhmrlaz6snvnzzmii7gxagr7rjrhrr.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1
88
+
89
+ ZQ = 2
90
+ HQ = 32
91
+ Q_LEN = ks0
92
+ ZKV = 2
93
+ KV_LEN = ks1
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 2
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = ks2
130
+ stride_kv_idx_h = ks3*ks4
131
+ stride_kv_idx_m = ks4
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to it
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831843
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order to since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = tl.full([1], False, tl.int1)
371
+ tmp2 = (m)
372
+ tmp3 = (n)
373
+ tmp4 = tmp2 >= tmp3
374
+ tmp5 = tmp3.to(tl.int64)
375
+ tmp6 = (off_z)
376
+ tmp7 = tl.load(in_ptr9 + tmp6)
377
+ tmp8 = tmp5 < tmp7
378
+ tmp9 = tmp2.to(tl.int64)
379
+ tmp10 = tmp9 < tmp7
380
+ tmp11 = tmp8 & tmp10
381
+ tmp12 = tmp4 & tmp11
382
+ tmp13 = tmp1 | tmp12
383
+ tmp14 = ks5
384
+ tmp15 = tmp3 >= tmp14
385
+ tmp16 = (tmp3 % tmp14)
386
+ tmp17 = tl.full([1], 0, tl.int32)
387
+ tmp18 = tmp16 != tmp17
388
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
389
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
390
+ tmp21 = tmp19 != tmp20
391
+ tmp22 = tmp18 & tmp21
392
+ tmp23 = tmp16 + tmp14
393
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
394
+ tmp25 = tmp24.to(tl.int64)
395
+ tmp26 = tmp25 < tmp7
396
+ tmp27 = tmp15 & tmp26
397
+ tmp28 = tmp3 - tmp2
398
+ tmp29 = (tmp28 % tmp14)
399
+ tmp30 = tmp29 != tmp17
400
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
401
+ tmp32 = tmp31 != tmp20
402
+ tmp33 = tmp30 & tmp32
403
+ tmp34 = tmp29 + tmp14
404
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
405
+ tmp36 = tmp35 == tmp17
406
+ tmp37 = tmp27 & tmp36
407
+ tmp38 = tmp13 | tmp37
408
+ mask_mod_output = tmp38
409
+
410
+
411
+ if CHECK_BLOCK_BOUNDARY:
412
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
413
+ # apply mask for partially unmasked blocks
414
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
415
+
416
+ if not PRESCALE_QK:
417
+ post_mod_scores *= RCP_LN2
418
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419
+
420
+ # -- compute scaling constant ---
421
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
422
+ if not ROWS_GUARANTEED_SAFE:
423
+ masked_out_rows = (m_ij == float("-inf"))
424
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
425
+ else:
426
+ m_ij_masked = m_ij
427
+
428
+ alpha = tl.math.exp2(m_i - m_ij_masked)
429
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
430
+
431
+ # NB: l_i update is pulled up here since it's a bit faster
432
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
433
+ # m_ij
434
+ l_i = l_i * alpha + tl.sum(p, 1)
435
+ # # -- scale and update acc --
436
+ acc = acc * alpha[:, None]
437
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
438
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
439
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
440
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
441
+
442
+ # -- update m_i
443
+ m_i = m_ij
444
+
445
+ return acc, l_i, m_i
446
+
447
+ @triton.jit
448
+ def forward_inner(
449
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
450
+ q, K, V,
451
+ desc_k, desc_v, Q_LEN, KV_LEN,
452
+ # accumulated values
453
+ acc, l_i, m_i,
454
+ # Offsets used as inputs to score_mod & mask_mod
455
+ # of size [BLOCK_M, BLOCK_N] or scalar.
456
+ off_z, off_h, offs_m, offs_n,
457
+ # Offsets needed for TMA loads
458
+ kv_start,
459
+ # blocksparse data
460
+ kv_indices, kv_num_blocks,
461
+ # start kv and end kv block
462
+ block_n_start, block_n_end,
463
+ MATMUL_PRECISION,
464
+ # Strides for K and V
465
+ stride_kk, stride_kn, stride_vn, stride_vk,
466
+ IS_FULL_BLOCKS,
467
+ ):
468
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
469
+ PRESCALE_QK : tl.constexpr = False
470
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
471
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
472
+ WRITE_DQ : tl.constexpr = True
473
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
474
+ OUTPUT_MAX : tl.constexpr = False
475
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
476
+ IS_DIVISIBLE : tl.constexpr = False
477
+ SM_SCALE : tl.constexpr = 0.08838834764831843
478
+ GQA_SHARED_HEADS : tl.constexpr = 4
479
+ HAS_FULL_BLOCKS : tl.constexpr = True
480
+ QK_HEAD_DIM : tl.constexpr = 128
481
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
482
+ V_HEAD_DIM : tl.constexpr = 128
483
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
484
+ SAFE_HEAD_DIM : tl.constexpr = True
485
+ USE_TMA : tl.constexpr = False
486
+ BLOCK_M : tl.constexpr = 128
487
+ BLOCK_N : tl.constexpr = 64
488
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
489
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
490
+ INDEX_DTYPE : tl.constexpr = tl.int32
491
+
492
+
493
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
494
+ RCP_LN2: tl.constexpr = 1.44269504
495
+
496
+ if PRESCALE_QK:
497
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
498
+
499
+ kv_offset = 0
500
+
501
+ # loop over k, v and update accumulator until block_n_end
502
+ for start_n in range(block_n_start, block_n_end):
503
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
504
+ if IS_DIVISIBLE:
505
+ acc, l_i, m_i = forward_block_mn(
506
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
507
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
508
+ # accumulated values
509
+ acc, l_i, m_i,
510
+ # Offsets
511
+ off_z, off_h, offs_m, offs_n,
512
+ # Offsets needed for TMA loads
513
+ kv_start,
514
+ kv_offset,
515
+ MATMUL_PRECISION, RCP_LN2,
516
+ # Strides for K and V
517
+ stride_kk, stride_kn, stride_vn, stride_vk,
518
+ IS_FULL_BLOCKS,
519
+ )
520
+ else:
521
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
522
+ # it's on par or slightly faster than only applying to the last block in fwd.
523
+ # However, we choose different strategy for bwd, where we only apply mod & mask
524
+ # to the last block because it's faster a lot.
525
+ acc, l_i, m_i = forward_block_mn(
526
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
527
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
528
+ # accumulated values
529
+ acc, l_i, m_i,
530
+ # Offsets
531
+ off_z, off_h, offs_m, offs_n,
532
+ # Offsets needed for TMA loads
533
+ kv_start,
534
+ kv_offset,
535
+ MATMUL_PRECISION, RCP_LN2,
536
+ # Strides for K and V
537
+ stride_kk, stride_kn, stride_vn, stride_vk,
538
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
539
+ )
540
+
541
+
542
+
543
+ offset = get_offset_for_next_block(
544
+ start_n, kv_indices, kv_num_blocks,
545
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
546
+ )
547
+
548
+ offs_n = offs_n + offset
549
+ kv_offset += offset
550
+
551
+
552
+ return acc, l_i, m_i
SpecForge-ext/cache/compiled_kernels/4y/c4yua3qi2b3xk6rn6ls5sdrsrpavp4zes7z62ki32y5ijfhzw4bb.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = True
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1
88
+
89
+ ZQ = 2
90
+ HQ = 32
91
+ Q_LEN = 2048
92
+ ZKV = 2
93
+ KV_LEN = 2048
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 2
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = 16
130
+ stride_kv_idx_h = 256
131
+ stride_kv_idx_m = 16
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to it
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = True
324
+ SM_SCALE : tl.constexpr = 0.08838834764831843
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order to since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = tl.full([1], False, tl.int1)
371
+ tmp2 = (m)
372
+ tmp3 = (n)
373
+ tmp4 = tmp2 >= tmp3
374
+ tmp5 = tmp3.to(tl.int64)
375
+ tmp6 = (off_z)
376
+ tmp7 = tl.load(in_ptr9 + tmp6)
377
+ tmp8 = tmp5 < tmp7
378
+ tmp9 = tmp2.to(tl.int64)
379
+ tmp10 = tmp9 < tmp7
380
+ tmp11 = tmp8 & tmp10
381
+ tmp12 = tmp4 & tmp11
382
+ tmp13 = tmp1 | tmp12
383
+ tmp14 = tl.full([1], 2048, tl.int32)
384
+ tmp15 = tmp3 >= tmp14
385
+ tmp16 = (tmp3 % tmp14)
386
+ tmp17 = tl.full([1], 0, tl.int32)
387
+ tmp18 = tmp16 != tmp17
388
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
389
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
390
+ tmp21 = tmp19 != tmp20
391
+ tmp22 = tmp18 & tmp21
392
+ tmp23 = tmp16 + tmp14
393
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
394
+ tmp25 = tmp24.to(tl.int64)
395
+ tmp26 = tmp25 < tmp7
396
+ tmp27 = tmp15 & tmp26
397
+ tmp28 = tmp3 - tmp2
398
+ tmp29 = (tmp28 % tmp14)
399
+ tmp30 = tmp29 != tmp17
400
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
401
+ tmp32 = tmp31 != tmp20
402
+ tmp33 = tmp30 & tmp32
403
+ tmp34 = tmp29 + tmp14
404
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
405
+ tmp36 = tmp35 == tmp17
406
+ tmp37 = tmp27 & tmp36
407
+ tmp38 = tmp13 | tmp37
408
+ mask_mod_output = tmp38
409
+
410
+
411
+ if CHECK_BLOCK_BOUNDARY:
412
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
413
+ # apply mask for partially unmasked blocks
414
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
415
+
416
+ if not PRESCALE_QK:
417
+ post_mod_scores *= RCP_LN2
418
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419
+
420
+ # -- compute scaling constant ---
421
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
422
+ if not ROWS_GUARANTEED_SAFE:
423
+ masked_out_rows = (m_ij == float("-inf"))
424
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
425
+ else:
426
+ m_ij_masked = m_ij
427
+
428
+ alpha = tl.math.exp2(m_i - m_ij_masked)
429
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
430
+
431
+ # NB: l_i update is pulled up here since it's a bit faster
432
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
433
+ # m_ij
434
+ l_i = l_i * alpha + tl.sum(p, 1)
435
+ # # -- scale and update acc --
436
+ acc = acc * alpha[:, None]
437
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
438
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
439
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
440
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
441
+
442
+ # -- update m_i
443
+ m_i = m_ij
444
+
445
+ return acc, l_i, m_i
446
+
447
+ @triton.jit
448
+ def forward_inner(
449
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
450
+ q, K, V,
451
+ desc_k, desc_v, Q_LEN, KV_LEN,
452
+ # accumulated values
453
+ acc, l_i, m_i,
454
+ # Offsets used as inputs to score_mod & mask_mod
455
+ # of size [BLOCK_M, BLOCK_N] or scalar.
456
+ off_z, off_h, offs_m, offs_n,
457
+ # Offsets needed for TMA loads
458
+ kv_start,
459
+ # blocksparse data
460
+ kv_indices, kv_num_blocks,
461
+ # start kv and end kv block
462
+ block_n_start, block_n_end,
463
+ MATMUL_PRECISION,
464
+ # Strides for K and V
465
+ stride_kk, stride_kn, stride_vn, stride_vk,
466
+ IS_FULL_BLOCKS,
467
+ ):
468
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
469
+ PRESCALE_QK : tl.constexpr = False
470
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
471
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
472
+ WRITE_DQ : tl.constexpr = True
473
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
474
+ OUTPUT_MAX : tl.constexpr = False
475
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
476
+ IS_DIVISIBLE : tl.constexpr = True
477
+ SM_SCALE : tl.constexpr = 0.08838834764831843
478
+ GQA_SHARED_HEADS : tl.constexpr = 4
479
+ HAS_FULL_BLOCKS : tl.constexpr = True
480
+ QK_HEAD_DIM : tl.constexpr = 128
481
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
482
+ V_HEAD_DIM : tl.constexpr = 128
483
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
484
+ SAFE_HEAD_DIM : tl.constexpr = True
485
+ USE_TMA : tl.constexpr = False
486
+ BLOCK_M : tl.constexpr = 128
487
+ BLOCK_N : tl.constexpr = 64
488
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
489
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
490
+ INDEX_DTYPE : tl.constexpr = tl.int32
491
+
492
+
493
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
494
+ RCP_LN2: tl.constexpr = 1.44269504
495
+
496
+ if PRESCALE_QK:
497
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
498
+
499
+ kv_offset = 0
500
+
501
+ # loop over k, v and update accumulator until block_n_end
502
+ for start_n in range(block_n_start, block_n_end):
503
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
504
+ if IS_DIVISIBLE:
505
+ acc, l_i, m_i = forward_block_mn(
506
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
507
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
508
+ # accumulated values
509
+ acc, l_i, m_i,
510
+ # Offsets
511
+ off_z, off_h, offs_m, offs_n,
512
+ # Offsets needed for TMA loads
513
+ kv_start,
514
+ kv_offset,
515
+ MATMUL_PRECISION, RCP_LN2,
516
+ # Strides for K and V
517
+ stride_kk, stride_kn, stride_vn, stride_vk,
518
+ IS_FULL_BLOCKS,
519
+ )
520
+ else:
521
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
522
+ # it's on par or slightly faster than only applying to the last block in fwd.
523
+ # However, we choose different strategy for bwd, where we only apply mod & mask
524
+ # to the last block because it's faster a lot.
525
+ acc, l_i, m_i = forward_block_mn(
526
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
527
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
528
+ # accumulated values
529
+ acc, l_i, m_i,
530
+ # Offsets
531
+ off_z, off_h, offs_m, offs_n,
532
+ # Offsets needed for TMA loads
533
+ kv_start,
534
+ kv_offset,
535
+ MATMUL_PRECISION, RCP_LN2,
536
+ # Strides for K and V
537
+ stride_kk, stride_kn, stride_vn, stride_vk,
538
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
539
+ )
540
+
541
+
542
+
543
+ offset = get_offset_for_next_block(
544
+ start_n, kv_indices, kv_num_blocks,
545
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
546
+ )
547
+
548
+ offs_n = offs_n + offset
549
+ kv_offset += offset
550
+
551
+
552
+ return acc, l_i, m_i
SpecForge-ext/cache/compiled_kernels/6f/ba9cb84a5b5ef82fddf7d6be536aa0e0768988ffdd80996052da5fb28f5bfff3.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "INOFCMBF4AOGTUSNRPBLV7E37E4P43AGG4323SKXUALONOEWOJUA"}
SpecForge-ext/cache/compiled_kernels/6n/c6njycmp52a4ww57u7ir3n6hwhaktjczce3zzyrhirlmhjbkrrhg.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['9_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
17
+ import triton
18
+ import triton.language as tl
19
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
20
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
21
+
22
+ aten = torch.ops.aten
23
+ inductor_ops = torch.ops.inductor
24
+ _quantized = torch.ops._quantized
25
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
26
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
27
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
28
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
29
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
30
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
31
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
32
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
33
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
34
+ async_compile = AsyncCompile()
35
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
36
+
37
+
38
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/s7/cs7qhjlt3qagwyyic2oiyost4mzjtbyquc6muggyzudwzg2u4vbt.py
39
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
40
+ # Source node to ATen node mapping:
41
+ # flex_attention => flex_attention
42
+ # Graph fragment:
43
+ # %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_1]
44
+ # %primals_3 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:6" = PlaceHolder[target=primals_3]
45
+ # %primals_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:6" = PlaceHolder[target=primals_5]
46
+ # %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=getitem_1]
47
+ # %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=buf1]
48
+ # %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_9]
49
+ # %primals_7 : Tensor "i32[2, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:6" = PlaceHolder[target=primals_7]
50
+ # %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_11]
51
+ # %primals_13 : Tensor "i32[2, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:6" = PlaceHolder[target=primals_13]
52
+ # %primals_10 : Tensor "i64[2][1]cuda:6" = PlaceHolder[target=primals_10]
53
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_3, %primals_5, %sdpa_score0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {})
54
+ # return %getitem
55
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
56
+ import triton
57
+ import triton.language as tl
58
+
59
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
60
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
61
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
62
+
63
+ @triton_heuristics.template(
64
+
65
+ num_stages=3,
66
+ num_warps=8,
67
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
68
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
69
+
70
+ )
71
+ @triton.jit
72
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1):
73
+ PRESCALE_QK : tl.constexpr = False
74
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
75
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
76
+ WRITE_DQ : tl.constexpr = True
77
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
78
+ OUTPUT_MAX : tl.constexpr = False
79
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
80
+ IS_DIVISIBLE : tl.constexpr = False
81
+ SM_SCALE : tl.constexpr = 0.08838834764831843
82
+ GQA_SHARED_HEADS : tl.constexpr = 4
83
+ HAS_FULL_BLOCKS : tl.constexpr = True
84
+ QK_HEAD_DIM : tl.constexpr = 128
85
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ V_HEAD_DIM : tl.constexpr = 128
87
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
88
+ SAFE_HEAD_DIM : tl.constexpr = True
89
+ USE_TMA : tl.constexpr = False
90
+ BLOCK_M : tl.constexpr = 128
91
+ BLOCK_N : tl.constexpr = 64
92
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
93
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
94
+ INDEX_DTYPE : tl.constexpr = tl.int32
95
+ Q = arg_Q
96
+ K = arg_K
97
+ V = arg_V
98
+ LSE = arg_LSE
99
+ MAX = arg_MAX
100
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
101
+ KV_IDX = arg_KV_IDX
102
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
103
+ FULL_KV_IDX = arg_FULL_KV_IDX
104
+
105
+ # Sub notation for this kernel:
106
+ #
107
+ # Q: Query, K: Key, V: Value
108
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
109
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
110
+ # V_HEAD_DIM: The dimension of the value embeddings
111
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
112
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
113
+ #
114
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
115
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
116
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
117
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
118
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
119
+ #
120
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
121
+ #
122
+ # (Modifiable) Performance tuning options
123
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
124
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
125
+
126
+ # The below are kernel options that can be applied for certain score_mods,
127
+ # or involve a numerics vs. perf tradeoff
128
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
129
+ # about 20% more numerical error, but slightly faster.
130
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
131
+ # is not masked out? If so, we can skip an extra safety check
132
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
133
+ # contiguous? If so, we don't need to do an indirect jump for every block
134
+
135
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
136
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
137
+
138
+ # Define strides of inputs
139
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
140
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1
141
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1
142
+
143
+ ZQ = 2
144
+ HQ = 32
145
+ Q_LEN = 2048
146
+ ZKV = 2
147
+ KV_LEN = ks0
148
+
149
+ MATMUL_PRECISION = Q.dtype.element_ty
150
+
151
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
152
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
153
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
154
+
155
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
156
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
157
+ off_zkv = off_zq % ZKV
158
+ off_hkv = off_hq // GQA_SHARED_HEADS
159
+ off_g = off_hq % GQA_SHARED_HEADS
160
+
161
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
162
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
163
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
164
+
165
+ Q = Q + q_offset
166
+ K = K + k_offset
167
+ V = V + v_offset
168
+
169
+ # Setting up the TMA descriptors for Q, K, V
170
+ desc_q = None
171
+ desc_k = None
172
+ desc_v = None
173
+
174
+ SPARSE_Z = 2
175
+ SPARSE_HQ = 1
176
+
177
+ sparse_idx_z = off_zq % SPARSE_Z
178
+ sparse_idx_hq = off_hq % SPARSE_HQ
179
+
180
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
181
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
182
+
183
+ stride_kv_num_blks_h = 16
184
+ stride_kv_idx_h = 16*ks1
185
+ stride_kv_idx_m = ks1
186
+
187
+ # initialize pointer to m and l
188
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
189
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
190
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
191
+
192
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
193
+
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
196
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
197
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
198
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
199
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
200
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
201
+
202
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
203
+ # We don't know anything "special" about these blocks, so we need to apply
204
+ # both score_mod and mask_mod to it
205
+ kv_indices = KV_IDX + sparse_kv_idx_offset
206
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
207
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
208
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
209
+
210
+
211
+ # K and V pointers will be passed directly to forward_inner
212
+
213
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
214
+
215
+
216
+ acc, l_i, m_i = forward_inner(
217
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
218
+ q, K, V,
219
+ desc_k, desc_v, Q_LEN, KV_LEN,
220
+ acc, l_i, m_i,
221
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
222
+ kv_start,
223
+ kv_indices, kv_num_blocks,
224
+ 0, block_n_end,
225
+ MATMUL_PRECISION,
226
+ stride_kk, stride_kn, stride_vn, stride_vk,
227
+ IS_FULL_BLOCKS=False,
228
+ )
229
+
230
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
231
+ # We know these blocks are guaranteed to be "full", so we don't need to
232
+ # apply mask_mod to them - only score_mod
233
+ if HAS_FULL_BLOCKS:
234
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
235
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
236
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
237
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
238
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
239
+ # K and V pointers will be passed directly to forward_inner
240
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
241
+
242
+ acc, l_i, m_i = forward_inner(
243
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
244
+ q, K, V,
245
+ desc_k, desc_v, Q_LEN, KV_LEN,
246
+ acc, l_i, m_i,
247
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
248
+ kv_start,
249
+ kv_indices, kv_num_blocks,
250
+ 0, block_n_end,
251
+ MATMUL_PRECISION,
252
+ stride_kk, stride_kn, stride_vn, stride_vk,
253
+ IS_FULL_BLOCKS=True,
254
+ )
255
+
256
+
257
+ # [Note] Handle fully masked out rows:
258
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
259
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
260
+ l_i = tl.where(l_i == 0.0, 1, l_i)
261
+
262
+ acc = acc / l_i[:, None]
263
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
264
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
265
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
266
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
267
+
268
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
269
+
270
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
271
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
272
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
273
+
274
+ if OUTPUT_LOGSUMEXP:
275
+ off_hz = off_zq * HQ + off_hq
276
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
277
+ lse = m_i + tl.math.log2(l_i)
278
+ if IS_DIVISIBLE:
279
+ tl.store(l_ptrs, lse)
280
+ else:
281
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
282
+
283
+ if OUTPUT_MAX:
284
+ off_hz = off_zq * HQ + off_hq
285
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
286
+ if IS_DIVISIBLE:
287
+ tl.store(max_ptrs, m_i)
288
+ else:
289
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
290
+
291
+
292
+ # Utility triton funcs
293
+ @triton.jit
294
+ def get_offset_for_next_block(
295
+ loop_iter, col_indices, total_blocks,
296
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
297
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
298
+ ):
299
+ if BLOCKS_ARE_CONTIGUOUS:
300
+ return BLOCK
301
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
302
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
303
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
304
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
305
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
306
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
307
+ return offset
308
+
309
+ @triton.jit
310
+ def get_bounded_indices(indices, max_len=None):
311
+ return indices % max_len if max_len is not None else indices
312
+
313
+ @triton.jit
314
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
315
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr)
317
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
319
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
320
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
321
+ else:
322
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
323
+
324
+ @triton.jit
325
+ def load_checked_2d(
326
+ ptr,
327
+ offs_m,
328
+ offs_n,
329
+ stride_m,
330
+ stride_n,
331
+ IS_DIVISIBLE_M: tl.constexpr,
332
+ IS_DIVISIBLE_N: tl.constexpr,
333
+ M_LEN: tl.constexpr,
334
+ N_LEN: tl.constexpr,
335
+ ):
336
+ # Calculate final pointer if strides are provided
337
+ if stride_m is not None and stride_n is not None:
338
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
339
+
340
+ # Handle all masking cases
341
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
343
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
345
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
346
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
347
+ else: # Both divisible
348
+ return tl.load(ptr)
349
+
350
+
351
+ # Common Imports
352
+ @triton.jit
353
+ def forward_block_mn(
354
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
355
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
356
+ # accumulated values
357
+ acc, l_i, m_i,
358
+ # Offsets
359
+ off_z, off_h, offs_m, offs_n,
360
+ # Offsets needed for TMA loads
361
+ kv_start,
362
+ kv_offset,
363
+ MATMUL_PRECISION, RCP_LN2,
364
+ # Strides for K and V
365
+ stride_kk, stride_kn, stride_vn, stride_vk,
366
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
367
+
368
+ ):
369
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
370
+ PRESCALE_QK : tl.constexpr = False
371
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
372
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
373
+ WRITE_DQ : tl.constexpr = True
374
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
375
+ OUTPUT_MAX : tl.constexpr = False
376
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
377
+ IS_DIVISIBLE : tl.constexpr = False
378
+ SM_SCALE : tl.constexpr = 0.08838834764831843
379
+ GQA_SHARED_HEADS : tl.constexpr = 4
380
+ HAS_FULL_BLOCKS : tl.constexpr = True
381
+ QK_HEAD_DIM : tl.constexpr = 128
382
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ V_HEAD_DIM : tl.constexpr = 128
384
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
385
+ SAFE_HEAD_DIM : tl.constexpr = True
386
+ USE_TMA : tl.constexpr = False
387
+ BLOCK_M : tl.constexpr = 128
388
+ BLOCK_N : tl.constexpr = 64
389
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
390
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
391
+ INDEX_DTYPE : tl.constexpr = tl.int32
392
+
393
+
394
+ # -- load k --
395
+ # NB reversed order to since K is transposed
396
+ kv_base_offset = kv_start + kv_offset
397
+
398
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
399
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
400
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
401
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
402
+
403
+ k = tl.trans(k)
404
+ # -- compute qk ---
405
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
406
+ if not PRESCALE_QK:
407
+ qk *= SM_SCALE
408
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
409
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
410
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
411
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
412
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
413
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
414
+
415
+ tmp0 = (qk)
416
+ post_mod_scores = tmp0
417
+
418
+
419
+ if CHECK_BLOCK_BOUNDARY:
420
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
421
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
422
+
423
+ if not IS_FULL_BLOCKS:
424
+ tmp1 = tl.full([1], False, tl.int1)
425
+ tmp2 = (m)
426
+ tmp3 = (n)
427
+ tmp4 = tmp2 >= tmp3
428
+ tmp5 = tmp3.to(tl.int64)
429
+ tmp6 = (off_z)
430
+ tmp7 = tl.load(in_ptr9 + tmp6)
431
+ tmp8 = tmp5 < tmp7
432
+ tmp9 = tmp2.to(tl.int64)
433
+ tmp10 = tmp9 < tmp7
434
+ tmp11 = tmp8 & tmp10
435
+ tmp12 = tmp4 & tmp11
436
+ tmp13 = tmp1 | tmp12
437
+ tmp14 = tl.full([1], 2048, tl.int32)
438
+ tmp15 = tmp3 >= tmp14
439
+ tmp16 = (tmp3 % tmp14)
440
+ tmp17 = tl.full([1], 0, tl.int32)
441
+ tmp18 = tmp16 != tmp17
442
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
443
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
444
+ tmp21 = tmp19 != tmp20
445
+ tmp22 = tmp18 & tmp21
446
+ tmp23 = tmp16 + tmp14
447
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
448
+ tmp25 = tmp24.to(tl.int64)
449
+ tmp26 = tmp25 < tmp7
450
+ tmp27 = tmp15 & tmp26
451
+ tmp28 = tmp3 - tmp2
452
+ tmp29 = (tmp28 % tmp14)
453
+ tmp30 = tmp29 != tmp17
454
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
455
+ tmp32 = tmp31 != tmp20
456
+ tmp33 = tmp30 & tmp32
457
+ tmp34 = tmp29 + tmp14
458
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
459
+ tmp36 = tmp35 == tmp17
460
+ tmp37 = tmp27 & tmp36
461
+ tmp38 = tmp13 | tmp37
462
+ mask_mod_output = tmp38
463
+
464
+
465
+ if CHECK_BLOCK_BOUNDARY:
466
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
467
+ # apply mask for partially unmasked blocks
468
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
469
+
470
+ if not PRESCALE_QK:
471
+ post_mod_scores *= RCP_LN2
472
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
473
+
474
+ # -- compute scaling constant ---
475
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
476
+ if not ROWS_GUARANTEED_SAFE:
477
+ masked_out_rows = (m_ij == float("-inf"))
478
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
479
+ else:
480
+ m_ij_masked = m_ij
481
+
482
+ alpha = tl.math.exp2(m_i - m_ij_masked)
483
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
484
+
485
+ # NB: l_i update is pulled up here since it's a bit faster
486
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
487
+ # m_ij
488
+ l_i = l_i * alpha + tl.sum(p, 1)
489
+ # # -- scale and update acc --
490
+ acc = acc * alpha[:, None]
491
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
492
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
493
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
494
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
495
+
496
+ # -- update m_i
497
+ m_i = m_ij
498
+
499
+ return acc, l_i, m_i
500
+
501
+ @triton.jit
502
+ def forward_inner(
503
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
504
+ q, K, V,
505
+ desc_k, desc_v, Q_LEN, KV_LEN,
506
+ # accumulated values
507
+ acc, l_i, m_i,
508
+ # Offsets used as inputs to score_mod & mask_mod
509
+ # of size [BLOCK_M, BLOCK_N] or scalar.
510
+ off_z, off_h, offs_m, offs_n,
511
+ # Offsets needed for TMA loads
512
+ kv_start,
513
+ # blocksparse data
514
+ kv_indices, kv_num_blocks,
515
+ # start kv and end kv block
516
+ block_n_start, block_n_end,
517
+ MATMUL_PRECISION,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS,
521
+ ):
522
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
523
+ PRESCALE_QK : tl.constexpr = False
524
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
525
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
526
+ WRITE_DQ : tl.constexpr = True
527
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
528
+ OUTPUT_MAX : tl.constexpr = False
529
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
530
+ IS_DIVISIBLE : tl.constexpr = False
531
+ SM_SCALE : tl.constexpr = 0.08838834764831843
532
+ GQA_SHARED_HEADS : tl.constexpr = 4
533
+ HAS_FULL_BLOCKS : tl.constexpr = True
534
+ QK_HEAD_DIM : tl.constexpr = 128
535
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
536
+ V_HEAD_DIM : tl.constexpr = 128
537
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
538
+ SAFE_HEAD_DIM : tl.constexpr = True
539
+ USE_TMA : tl.constexpr = False
540
+ BLOCK_M : tl.constexpr = 128
541
+ BLOCK_N : tl.constexpr = 64
542
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
543
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
544
+ INDEX_DTYPE : tl.constexpr = tl.int32
545
+
546
+
547
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
548
+ RCP_LN2: tl.constexpr = 1.44269504
549
+
550
+ if PRESCALE_QK:
551
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
552
+
553
+ kv_offset = 0
554
+
555
+ # loop over k, v and update accumulator until block_n_end
556
+ for start_n in range(block_n_start, block_n_end):
557
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
558
+ if IS_DIVISIBLE:
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS,
573
+ )
574
+ else:
575
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
576
+ # it's on par or slightly faster than only applying to the last block in fwd.
577
+ # However, we choose different strategy for bwd, where we only apply mod & mask
578
+ # to the last block because it's faster a lot.
579
+ acc, l_i, m_i = forward_block_mn(
580
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
581
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
582
+ # accumulated values
583
+ acc, l_i, m_i,
584
+ # Offsets
585
+ off_z, off_h, offs_m, offs_n,
586
+ # Offsets needed for TMA loads
587
+ kv_start,
588
+ kv_offset,
589
+ MATMUL_PRECISION, RCP_LN2,
590
+ # Strides for K and V
591
+ stride_kk, stride_kn, stride_vn, stride_vk,
592
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
593
+ )
594
+
595
+
596
+
597
+ offset = get_offset_for_next_block(
598
+ start_n, kv_indices, kv_num_blocks,
599
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
600
+ )
601
+
602
+ offs_n = offs_n + offset
603
+ kv_offset += offset
604
+
605
+
606
+ return acc, l_i, m_i
607
+ ''', device_str='cuda')
608
+
609
+
610
+ async_compile.wait(globals())
611
+ del async_compile
612
+
613
+ class Runner:
614
+ def __init__(self, partitions):
615
+ self.partitions = partitions
616
+
617
+ def recursively_apply_fns(self, fns):
618
+ new_callables = []
619
+ for fn, c in zip(fns, self.partitions):
620
+ new_callables.append(fn(c))
621
+ self.partitions = new_callables
622
+
623
+ def call(self, args):
624
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21 = args
625
+ args.clear()
626
+ s0 = primals_2
627
+ s43 = primals_4
628
+ s72 = primals_6
629
+ s71 = primals_8
630
+ s4 = primals_12
631
+ s56 = primals_14
632
+ s84 = primals_16
633
+ s99 = primals_18
634
+ s6 = primals_20
635
+ assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1))
636
+ assert_size_stride(primals_3, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
637
+ assert_size_stride(primals_5, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
638
+ assert_size_stride(primals_7, (2, 1, 16, s72), (16*s72, 16*s72, s72, 1))
639
+ assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1))
640
+ assert_size_stride(primals_10, (2, ), (1, ))
641
+ assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1))
642
+ assert_size_stride(primals_13, (2, 1, 16, s4), (16*s4, 16*s4, s4, 1))
643
+ assert_size_stride(primals_15, (2, 1, s56), (s56, s56, 1))
644
+ assert_size_stride(primals_17, (2, 1, s84, 16), (16*s84, 16*s84, 16, 1))
645
+ assert_size_stride(primals_19, (2, 1, s99), (s99, s99, 1))
646
+ assert_size_stride(primals_21, (2, 1, s6, 16), (16*s6, 16*s6, 16, 1))
647
+ with torch.cuda._DeviceGuard(6):
648
+ torch.cuda.set_device(6)
649
+ buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32)
650
+ buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32)
651
+ buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16)
652
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
653
+ stream6 = get_raw_stream(6)
654
+ triton_tem_fused_0.run(primals_1, primals_3, primals_5, buf0, buf1, primals_9, primals_7, primals_11, primals_13, primals_10, buf2, s0, s72, 16, 2, 32, stream=stream6)
655
+ del buf1
656
+ return (buf2, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, buf2, buf0, s0, s72, s4, s56, s84, s99, s6, )
657
+
658
+ runner = Runner(partitions=[])
659
+ call = runner.call
660
+ recursively_apply_fns = runner.recursively_apply_fns
661
+
662
+
663
+ def benchmark_compiled_module(times=10, repeat=10):
664
+ from torch._dynamo.testing import rand_strided
665
+ from torch._inductor.utils import print_performance
666
+ primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16)
667
+ primals_2 = 4096
668
+ primals_3 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:6', dtype=torch.bfloat16)
669
+ primals_4 = 4096
670
+ primals_5 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:6', dtype=torch.bfloat16)
671
+ primals_6 = 32
672
+ primals_7 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:6', dtype=torch.int32)
673
+ primals_8 = 4096
674
+ primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32)
675
+ primals_10 = rand_strided((2, ), (1, ), device='cuda:6', dtype=torch.int64)
676
+ primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32)
677
+ primals_12 = 32
678
+ primals_13 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:6', dtype=torch.int32)
679
+ primals_14 = 32
680
+ primals_15 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:6', dtype=torch.int32)
681
+ primals_16 = 32
682
+ primals_17 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:6', dtype=torch.int32)
683
+ primals_18 = 32
684
+ primals_19 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:6', dtype=torch.int32)
685
+ primals_20 = 32
686
+ primals_21 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:6', dtype=torch.int32)
687
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21])
688
+ return print_performance(fn, times=times, repeat=repeat)
689
+
690
+
691
+ if __name__ == "__main__":
692
+ from torch._inductor.wrapper_benchmark import compiled_module_main
693
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/7k/c7kogmtwjpemxq6qqxi6bohljmze6cjf34eo47hpufuxmpjep3yw.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['4_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/cr/ccr5s7nffy4cqd7a3lcq3cnv2prruzwzc7chchf776jguuqqh5bc.py
38
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
39
+ # Source node to ATen node mapping:
40
+ # cos => squeeze_1
41
+ # cos_1 => unsqueeze
42
+ # getitem => index
43
+ # getitem_1 => index_1
44
+ # sin => squeeze_3
45
+ # sin_1 => unsqueeze_1
46
+ # squeeze => squeeze
47
+ # squeeze_2 => squeeze_2
48
+ # Graph fragment:
49
+ # %tangents_2 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:3" = PlaceHolder[target=tangents_2]
50
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:3" = PlaceHolder[target=primals_8]
51
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_6]
52
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_4]
53
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
54
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
55
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
56
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
57
+ # %mul_84 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {})
58
+ # %slice_5 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {})
59
+ # %slice_6 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs = {})
60
+ # %neg_2 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {})
61
+ # %full_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_10, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:3, pin_memory: False})
62
+ # %slice_scatter_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 9223372036854775807), kwargs = {})
63
+ # %slice_scatter_default_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {})
64
+ # %add_100 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {})
65
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
66
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
67
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
68
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
69
+ # %mul_85 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {})
70
+ # %add_101 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {})
71
+ # return %add_101
72
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', '''
73
+ import triton
74
+ import triton.language as tl
75
+
76
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
77
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
78
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
79
+ triton_helpers.set_driver_to_gpu()
80
+
81
+ @triton_heuristics.pointwise(
82
+ size_hints={'x': 16777216},
83
+ filename=__file__,
84
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
85
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
86
+ min_elem_per_thread=0
87
+ )
88
+ @triton.jit
89
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
90
+ xoffset = tl.program_id(0) * XBLOCK
91
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
92
+ xmask = xindex < xnumel
93
+ x0 = (xindex % ks0)
94
+ x3 = xindex
95
+ x1 = ((xindex // ks0) % ks1)
96
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
97
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
98
+ tmp0 = x0
99
+ tmp1 = ks0 // 2
100
+ tmp2 = tmp0 >= tmp1
101
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
102
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
103
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
104
+ tmp6 = tmp4 + tmp5
105
+ tmp7 = tmp4 < 0
106
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
107
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
108
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
109
+ tmp11 = tmp3 * tmp10
110
+ tmp12 = -tmp11
111
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
112
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
113
+ tmp15 = 0.0
114
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
115
+ tmp17 = tmp0 < tmp1
116
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
117
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
118
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
119
+ tmp21 = tmp19 + tmp20
120
+ tmp22 = tmp19 < 0
121
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
122
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
123
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
124
+ tmp26 = tmp18 * tmp25
125
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
126
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
127
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
128
+ tmp30 = tmp16 + tmp29
129
+ tmp33 = ks3
130
+ tmp34 = tmp32 + tmp33
131
+ tmp35 = tmp32 < 0
132
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
133
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
134
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
135
+ tmp39 = tmp31 * tmp38
136
+ tmp40 = tmp30 + tmp39
137
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
138
+ ''', device_str='cuda')
139
+
140
+
141
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sc/cscnljzdwi65mf6bzwdkbxxogrdjjknvycbgzdyjhcz5fx6umlk2.py
142
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
143
+ # Source node to ATen node mapping:
144
+ # cos => squeeze_1
145
+ # cos_1 => unsqueeze
146
+ # getitem => index
147
+ # getitem_1 => index_1
148
+ # sin => squeeze_3
149
+ # sin_1 => unsqueeze_1
150
+ # squeeze => squeeze
151
+ # squeeze_2 => squeeze_2
152
+ # Graph fragment:
153
+ # %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3" = PlaceHolder[target=tangents_1]
154
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:3" = PlaceHolder[target=primals_8]
155
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_6]
156
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_4]
157
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
158
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
159
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
160
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
161
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
162
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
163
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
164
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
165
+ # %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {})
166
+ # %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {})
167
+ # %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {})
168
+ # %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {})
169
+ # %full_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:3, pin_memory: False})
170
+ # %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), kwargs = {})
171
+ # %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {})
172
+ # %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {})
173
+ # %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {})
174
+ # %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {})
175
+ # return %add_107
176
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', '''
177
+ import triton
178
+ import triton.language as tl
179
+
180
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
181
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
182
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
183
+ triton_helpers.set_driver_to_gpu()
184
+
185
+ @triton_heuristics.pointwise(
186
+ size_hints={'x': 67108864},
187
+ filename=__file__,
188
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
189
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
190
+ min_elem_per_thread=0
191
+ )
192
+ @triton.jit
193
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
194
+ xoffset = tl.program_id(0) * XBLOCK
195
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
196
+ xmask = xindex < xnumel
197
+ x0 = (xindex % ks0)
198
+ x3 = xindex
199
+ x1 = ((xindex // ks0) % ks1)
200
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
201
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
202
+ tmp0 = x0
203
+ tmp1 = ks0 // 2
204
+ tmp2 = tmp0 >= tmp1
205
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
206
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
207
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
208
+ tmp6 = tmp4 + tmp5
209
+ tmp7 = tmp4 < 0
210
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
211
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
212
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
213
+ tmp11 = tmp3 * tmp10
214
+ tmp12 = -tmp11
215
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
216
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
217
+ tmp15 = 0.0
218
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
219
+ tmp17 = tmp0 < tmp1
220
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
221
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
222
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
223
+ tmp21 = tmp19 + tmp20
224
+ tmp22 = tmp19 < 0
225
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
226
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
227
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
228
+ tmp26 = tmp18 * tmp25
229
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
230
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
231
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
232
+ tmp30 = tmp16 + tmp29
233
+ tmp33 = ks3
234
+ tmp34 = tmp32 + tmp33
235
+ tmp35 = tmp32 < 0
236
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
237
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
238
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
239
+ tmp39 = tmp31 * tmp38
240
+ tmp40 = tmp30 + tmp39
241
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
242
+ ''', device_str='cuda')
243
+
244
+
245
+ async_compile.wait(globals())
246
+ del async_compile
247
+
248
+ class Runner:
249
+ def __init__(self, partitions):
250
+ self.partitions = partitions
251
+
252
+ def recursively_apply_fns(self, fns):
253
+ new_callables = []
254
+ for fn, c in zip(fns, self.partitions):
255
+ new_callables.append(fn(c))
256
+ self.partitions = new_callables
257
+
258
+ def call(self, args):
259
+ primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args
260
+ args.clear()
261
+ s24 = primals_2
262
+ s9 = primals_7
263
+ s48 = primals_10
264
+ s34 = primals_11
265
+ s92 = primals_1
266
+ s96 = primals_3
267
+ s79 = primals_5
268
+ assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1))
269
+ assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1))
270
+ assert_size_stride(primals_8, (1, s9), (s9, 1))
271
+ assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1))
272
+ assert_size_stride(tangents_2, (s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1))
273
+ with torch.cuda._DeviceGuard(3):
274
+ torch.cuda.set_device(3)
275
+ buf0 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1), torch.bfloat16)
276
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
277
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s9*s48*s48
278
+ stream3 = get_raw_stream(3)
279
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream3)
280
+ del tangents_2
281
+ buf1 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16)
282
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
283
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9
284
+ stream3 = get_raw_stream(3)
285
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream3)
286
+ del primals_4
287
+ del primals_6
288
+ del primals_8
289
+ del tangents_1
290
+ return (None, None, None, None, None, None, None, None, None, None, None, buf1, buf0, )
291
+
292
+ runner = Runner(partitions=[])
293
+ call = runner.call
294
+ recursively_apply_fns = runner.recursively_apply_fns
295
+
296
+
297
+ def benchmark_compiled_module(times=10, repeat=10):
298
+ from torch._dynamo.testing import rand_strided
299
+ from torch._inductor.utils import print_performance
300
+ primals_2 = 128
301
+ primals_7 = 2048
302
+ primals_10 = 8
303
+ primals_11 = 32
304
+ primals_1 = 2048
305
+ primals_3 = 5245440
306
+ primals_5 = 2048
307
+ floordiv = 64
308
+ add_96 = 64
309
+ primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:3', dtype=torch.bfloat16)
310
+ primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:3', dtype=torch.bfloat16)
311
+ primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:3', dtype=torch.int64)
312
+ tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16)
313
+ tangents_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16)
314
+ fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2])
315
+ return print_performance(fn, times=times, repeat=repeat)
316
+
317
+
318
+ if __name__ == "__main__":
319
+ from torch._inductor.wrapper_benchmark import compiled_module_main
320
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/7p/c7ph4dk7ghsg37h7a46klnkhb6rck4rpgxyqg7fjyewxnxqk5vvs.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 16384, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 1048576000}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 16384
20
+ r0_numel = 32000
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x0 = xindex
29
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
30
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
31
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
32
+ r0_index = r0_offset + r0_base
33
+ r0_mask = r0_index < r0_numel
34
+ roffset = r0_offset
35
+ rindex = r0_index
36
+ r0_1 = r0_index
37
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
38
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
39
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
40
+ _tmp2, _tmp2_index, tmp1, rindex
41
+ )
42
+ _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
43
+ _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
44
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
45
+ tmp2 = tmp2_idx[:, None]
46
+ tl.store(out_ptr0 + (x0), tmp2, None)
SpecForge-ext/cache/compiled_kernels/ag/caglk6whzazaqxxtfwcwjz3xhkspqbhu4cpbiwsvmmwxpmmmtst6.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['11_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/km/ckmybvsvduh5cqerbakqni4rsg2ms7xz5hoaifsmkr3dxydk73sv.py
38
+ # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
39
+ # Source node to ATen node mapping:
40
+ # target_head => convert_element_type
41
+ # target_p => div
42
+ # Graph fragment:
43
+ # %arg1_1 : Tensor "bf16[2, s67, 32000][32000*s67, 32000, 1]cuda:5" = PlaceHolder[target=arg1_1]
44
+ # %getitem : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:5" = PlaceHolder[target=getitem]
45
+ # %getitem_1 : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:5" = PlaceHolder[target=getitem_1]
46
+ # %convert_element_type : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg1_1, torch.float32), kwargs = {})
47
+ # %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {})
48
+ # %sub_tensor : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {})
49
+ # %exp_default : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {})
50
+ # %div : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {})
51
+ # return %getitem,%getitem_1,%div
52
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', '''
53
+ import triton
54
+ import triton.language as tl
55
+
56
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
57
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
58
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
59
+ triton_helpers.set_driver_to_gpu()
60
+
61
+ @triton_heuristics.reduction(
62
+ size_hints={'x': 4096, 'r0_': 32768},
63
+ reduction_hint=ReductionHint.INNER,
64
+ filename=__file__,
65
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
67
+ )
68
+ @triton.jit
69
+ def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
70
+ r0_numel = 32000
71
+ rnumel = r0_numel
72
+ RBLOCK: tl.constexpr = R0_BLOCK
73
+ xoffset = tl.program_id(0) * XBLOCK
74
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
75
+ xmask = xindex < xnumel
76
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
77
+ rbase = r0_base
78
+ x0 = xindex
79
+ _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
80
+ _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
81
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
82
+ r0_index = r0_offset + r0_base
83
+ r0_mask = r0_index < r0_numel
84
+ roffset = r0_offset
85
+ rindex = r0_index
86
+ r0_1 = r0_index
87
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
88
+ tmp1 = tmp0.to(tl.float32)
89
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
90
+
91
+ _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
92
+ _tmp3_max, _tmp3_sum, tmp2, False
93
+ )
94
+
95
+ _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max)
96
+ _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum)
97
+
98
+ tmp3, tmp4 = triton_helpers.online_softmax_reduce(
99
+ _tmp3_max, _tmp3_sum, 1, False)
100
+ tmp3 = tmp3[:, None]
101
+ tmp4 = tmp4[:, None]
102
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
103
+ r0_index = r0_offset + r0_base
104
+ r0_mask = r0_index < r0_numel
105
+ roffset = r0_offset
106
+ rindex = r0_index
107
+ r0_1 = r0_index
108
+ tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
109
+ tmp6 = tmp5.to(tl.float32)
110
+ tmp7 = tmp6 - tmp3
111
+ tmp8 = libdevice.exp(tmp7)
112
+ tmp9 = (tmp8 / tmp4)
113
+ tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask)
114
+ ''', device_str='cuda')
115
+
116
+
117
+ async_compile.wait(globals())
118
+ del async_compile
119
+
120
+ class Runner:
121
+ def __init__(self, partitions):
122
+ self.partitions = partitions
123
+
124
+ def recursively_apply_fns(self, fns):
125
+ new_callables = []
126
+ for fn, c in zip(fns, self.partitions):
127
+ new_callables.append(fn(c))
128
+ self.partitions = new_callables
129
+
130
+ def call(self, args):
131
+ arg0_1, arg1_1 = args
132
+ args.clear()
133
+ s67 = arg0_1
134
+ assert_size_stride(arg1_1, (2, s67, 32000), (32000*s67, 32000, 1))
135
+ with torch.cuda._DeviceGuard(5):
136
+ torch.cuda.set_device(5)
137
+ buf2 = empty_strided_cuda((2, s67, 32000), (32000*s67, 32000, 1), torch.float32)
138
+ # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
139
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel = 2*s67
140
+ stream5 = get_raw_stream(5)
141
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg1_1, buf2, triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel, 32000, stream=stream5)
142
+ del arg1_1
143
+ return (buf2, )
144
+
145
+ runner = Runner(partitions=[])
146
+ call = runner.call
147
+ recursively_apply_fns = runner.recursively_apply_fns
148
+
149
+
150
+ def benchmark_compiled_module(times=10, repeat=10):
151
+ from torch._dynamo.testing import rand_strided
152
+ from torch._inductor.utils import print_performance
153
+ arg0_1 = 1569
154
+ arg1_1 = rand_strided((2, 1569, 32000), (50208000, 32000, 1), device='cuda:5', dtype=torch.bfloat16)
155
+ fn = lambda: call([arg0_1, arg1_1])
156
+ return print_performance(fn, times=times, repeat=repeat)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ from torch._inductor.wrapper_benchmark import compiled_module_main
161
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/ao/caoqvgzvbk7exhnvkuijsznlx2ebywfk6vitynyaomz5hgx5szk5.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 4096, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 32000
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x0 = xindex
28
+ _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
29
+ _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
30
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
31
+ r0_index = r0_offset + r0_base
32
+ r0_mask = r0_index < r0_numel
33
+ roffset = r0_offset
34
+ rindex = r0_index
35
+ r0_1 = r0_index
36
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
37
+ tmp1 = tmp0.to(tl.float32)
38
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
39
+
40
+ _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
41
+ _tmp3_max, _tmp3_sum, tmp2, False
42
+ )
43
+
44
+ _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max)
45
+ _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum)
46
+
47
+ tmp3, tmp4 = triton_helpers.online_softmax_reduce(
48
+ _tmp3_max, _tmp3_sum, 1, False)
49
+ tmp3 = tmp3[:, None]
50
+ tmp4 = tmp4[:, None]
51
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
52
+ r0_index = r0_offset + r0_base
53
+ r0_mask = r0_index < r0_numel
54
+ roffset = r0_offset
55
+ rindex = r0_index
56
+ r0_1 = r0_index
57
+ tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
58
+ tmp6 = tmp5.to(tl.float32)
59
+ tmp7 = tmp6 - tmp3
60
+ tmp8 = libdevice.exp(tmp7)
61
+ tmp9 = (tmp8 / tmp4)
62
+ tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask)
SpecForge-ext/cache/compiled_kernels/aw/cawxo2ohlu2xus3es5wun6g3qdjlbckp23dho2fo6p76pf7ogcso.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['4_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/vq/cvqvxwsz5trm7yg2d2gcqm3fnjjobjar5tizng43rigkxges3nhj.py
38
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
39
+ # Source node to ATen node mapping:
40
+ # cos => squeeze_1
41
+ # cos_1 => unsqueeze
42
+ # getitem => index
43
+ # getitem_1 => index_1
44
+ # sin => squeeze_3
45
+ # sin_1 => unsqueeze_1
46
+ # squeeze => squeeze
47
+ # squeeze_2 => squeeze_2
48
+ # Graph fragment:
49
+ # %tangents_2 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:0" = PlaceHolder[target=tangents_2]
50
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:0" = PlaceHolder[target=primals_8]
51
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_6]
52
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_4]
53
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
54
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
55
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
56
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
57
+ # %mul_84 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {})
58
+ # %slice_5 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s24*s25*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {})
59
+ # %slice_6 : Tensor "bf16[s48, s25, s9, (s24//2)][s24*s25*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs = {})
60
+ # %neg_2 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s25*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {})
61
+ # %full_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_13, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
62
+ # %slice_scatter_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 9223372036854775807), kwargs = {})
63
+ # %slice_scatter_default_1 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {})
64
+ # %add_100 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {})
65
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
66
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
67
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
68
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
69
+ # %mul_85 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {})
70
+ # %add_101 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {})
71
+ # return %add_101
72
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', '''
73
+ import triton
74
+ import triton.language as tl
75
+
76
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
77
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
78
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
79
+ triton_helpers.set_driver_to_gpu()
80
+
81
+ @triton_heuristics.pointwise(
82
+ size_hints={'x': 4194304},
83
+ filename=__file__,
84
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
85
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
86
+ min_elem_per_thread=0
87
+ )
88
+ @triton.jit
89
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
90
+ xoffset = tl.program_id(0) * XBLOCK
91
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
92
+ xmask = xindex < xnumel
93
+ x0 = (xindex % ks0)
94
+ x3 = xindex
95
+ x1 = ((xindex // ks0) % ks1)
96
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
97
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
98
+ tmp0 = x0
99
+ tmp1 = ks0 // 2
100
+ tmp2 = tmp0 >= tmp1
101
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
102
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
103
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
104
+ tmp6 = tmp4 + tmp5
105
+ tmp7 = tmp4 < 0
106
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
107
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
108
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
109
+ tmp11 = tmp3 * tmp10
110
+ tmp12 = -tmp11
111
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
112
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
113
+ tmp15 = 0.0
114
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
115
+ tmp17 = tmp0 < tmp1
116
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
117
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
118
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
119
+ tmp21 = tmp19 + tmp20
120
+ tmp22 = tmp19 < 0
121
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
122
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
123
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
124
+ tmp26 = tmp18 * tmp25
125
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
126
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
127
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
128
+ tmp30 = tmp16 + tmp29
129
+ tmp33 = ks3
130
+ tmp34 = tmp32 + tmp33
131
+ tmp35 = tmp32 < 0
132
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
133
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
134
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
135
+ tmp39 = tmp31 * tmp38
136
+ tmp40 = tmp30 + tmp39
137
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
138
+ ''', device_str='cuda')
139
+
140
+
141
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zu/czu2jyesrdsgfrod6l7j2iof2pn657e57odk5qfyk2zi2uaqndjj.py
142
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
143
+ # Source node to ATen node mapping:
144
+ # cos => squeeze_1
145
+ # cos_1 => unsqueeze
146
+ # getitem => index
147
+ # getitem_1 => index_1
148
+ # sin => squeeze_3
149
+ # sin_1 => unsqueeze_1
150
+ # squeeze => squeeze
151
+ # squeeze_2 => squeeze_2
152
+ # Graph fragment:
153
+ # %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0" = PlaceHolder[target=tangents_1]
154
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:0" = PlaceHolder[target=primals_8]
155
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_6]
156
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_4]
157
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
158
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
159
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
160
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
161
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
162
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
163
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
164
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
165
+ # %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {})
166
+ # %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {})
167
+ # %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {})
168
+ # %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {})
169
+ # %full_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
170
+ # %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), kwargs = {})
171
+ # %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {})
172
+ # %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {})
173
+ # %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {})
174
+ # %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {})
175
+ # return %add_107
176
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', '''
177
+ import triton
178
+ import triton.language as tl
179
+
180
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
181
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
182
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
183
+ triton_helpers.set_driver_to_gpu()
184
+
185
+ @triton_heuristics.pointwise(
186
+ size_hints={'x': 16777216},
187
+ filename=__file__,
188
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
189
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
190
+ min_elem_per_thread=0
191
+ )
192
+ @triton.jit
193
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
194
+ xoffset = tl.program_id(0) * XBLOCK
195
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
196
+ xmask = xindex < xnumel
197
+ x0 = (xindex % ks0)
198
+ x3 = xindex
199
+ x1 = ((xindex // ks0) % ks1)
200
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
201
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
202
+ tmp0 = x0
203
+ tmp1 = ks0 // 2
204
+ tmp2 = tmp0 >= tmp1
205
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
206
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
207
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
208
+ tmp6 = tmp4 + tmp5
209
+ tmp7 = tmp4 < 0
210
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
211
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
212
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
213
+ tmp11 = tmp3 * tmp10
214
+ tmp12 = -tmp11
215
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
216
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
217
+ tmp15 = 0.0
218
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
219
+ tmp17 = tmp0 < tmp1
220
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
221
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
222
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
223
+ tmp21 = tmp19 + tmp20
224
+ tmp22 = tmp19 < 0
225
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
226
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
227
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
228
+ tmp26 = tmp18 * tmp25
229
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
230
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
231
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
232
+ tmp30 = tmp16 + tmp29
233
+ tmp33 = ks3
234
+ tmp34 = tmp32 + tmp33
235
+ tmp35 = tmp32 < 0
236
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
237
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
238
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
239
+ tmp39 = tmp31 * tmp38
240
+ tmp40 = tmp30 + tmp39
241
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
242
+ ''', device_str='cuda')
243
+
244
+
245
+ async_compile.wait(globals())
246
+ del async_compile
247
+
248
+ class Runner:
249
+ def __init__(self, partitions):
250
+ self.partitions = partitions
251
+
252
+ def recursively_apply_fns(self, fns):
253
+ new_callables = []
254
+ for fn, c in zip(fns, self.partitions):
255
+ new_callables.append(fn(c))
256
+ self.partitions = new_callables
257
+
258
+ def call(self, args):
259
+ primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args
260
+ args.clear()
261
+ s24 = primals_2
262
+ s9 = primals_7
263
+ s48 = primals_10
264
+ s34 = primals_11
265
+ s25 = primals_13
266
+ s92 = primals_1
267
+ s96 = primals_3
268
+ s79 = primals_5
269
+ assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1))
270
+ assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1))
271
+ assert_size_stride(primals_8, (1, s9), (s9, 1))
272
+ assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1))
273
+ assert_size_stride(tangents_2, (s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1))
274
+ with torch.cuda._DeviceGuard(0):
275
+ torch.cuda.set_device(0)
276
+ buf0 = empty_strided_cuda((s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1), torch.bfloat16)
277
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
278
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s25*s48*s9
279
+ stream0 = get_raw_stream(0)
280
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream0)
281
+ del tangents_2
282
+ buf1 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16)
283
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
284
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9
285
+ stream0 = get_raw_stream(0)
286
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream0)
287
+ del primals_4
288
+ del primals_6
289
+ del primals_8
290
+ del tangents_1
291
+ return (None, None, None, None, None, None, None, None, None, None, None, buf1, None, buf0, )
292
+
293
+ runner = Runner(partitions=[])
294
+ call = runner.call
295
+ recursively_apply_fns = runner.recursively_apply_fns
296
+
297
+
298
+ def benchmark_compiled_module(times=10, repeat=10):
299
+ from torch._dynamo.testing import rand_strided
300
+ from torch._inductor.utils import print_performance
301
+ primals_2 = 128
302
+ primals_7 = 2048
303
+ primals_10 = 2
304
+ primals_11 = 32
305
+ primals_13 = 8
306
+ primals_1 = 2048
307
+ primals_3 = 5245440
308
+ primals_5 = 2048
309
+ floordiv = 64
310
+ add_96 = 64
311
+ primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:0', dtype=torch.bfloat16)
312
+ primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:0', dtype=torch.bfloat16)
313
+ primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:0', dtype=torch.int64)
314
+ tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
315
+ tangents_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
316
+ fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2])
317
+ return print_performance(fn, times=times, repeat=repeat)
318
+
319
+
320
+ if __name__ == "__main__":
321
+ from torch._inductor.wrapper_benchmark import compiled_module_main
322
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/c4/3fc868fcdc136a60cbcdc167284005fb6cd4078af5cf939debad2799d55dedad.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 8, "R0_BLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "b70837e3723f218c7368cc2b49566dcd2bec3baf4c88b5e174a3f0822a6c86c0", "found_by_coordesc": false, "time_taken_ms": 142, "triton_cache_hash": "BZ2FPB5QIE7EHR6P7EPVPHR4HKS3YX3QQPIWQIT2R3EOJOAVWCGA"}
SpecForge-ext/cache/compiled_kernels/c4/cc44tmaxtaxohkbf52w5omwmrxhrmn6iuplipagv7rlnxaz6dkey.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1
88
+
89
+ ZQ = 8
90
+ HQ = 32
91
+ Q_LEN = ks0
92
+ ZKV = 8
93
+ KV_LEN = ks1
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 8
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = ks2
130
+ stride_kv_idx_h = ks3*ks4
131
+ stride_kv_idx_m = ks4
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to it
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831843
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order to since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = tl.full([1], False, tl.int1)
371
+ tmp2 = (m)
372
+ tmp3 = (n)
373
+ tmp4 = tmp2 >= tmp3
374
+ tmp5 = tmp3.to(tl.int64)
375
+ tmp6 = (off_z)
376
+ tmp7 = tl.load(in_ptr9 + tmp6)
377
+ tmp8 = tmp5 < tmp7
378
+ tmp9 = tmp2.to(tl.int64)
379
+ tmp10 = tmp9 < tmp7
380
+ tmp11 = tmp8 & tmp10
381
+ tmp12 = tmp4 & tmp11
382
+ tmp13 = tmp1 | tmp12
383
+ tmp14 = ks5
384
+ tmp15 = tmp3 >= tmp14
385
+ tmp16 = (tmp3 % tmp14)
386
+ tmp17 = tl.full([1], 0, tl.int32)
387
+ tmp18 = tmp16 != tmp17
388
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
389
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
390
+ tmp21 = tmp19 != tmp20
391
+ tmp22 = tmp18 & tmp21
392
+ tmp23 = tmp16 + tmp14
393
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
394
+ tmp25 = tmp24.to(tl.int64)
395
+ tmp26 = tmp25 < tmp7
396
+ tmp27 = tmp15 & tmp26
397
+ tmp28 = tmp3 - tmp2
398
+ tmp29 = (tmp28 % tmp14)
399
+ tmp30 = tmp29 != tmp17
400
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
401
+ tmp32 = tmp31 != tmp20
402
+ tmp33 = tmp30 & tmp32
403
+ tmp34 = tmp29 + tmp14
404
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
405
+ tmp36 = tmp35 == tmp17
406
+ tmp37 = tmp27 & tmp36
407
+ tmp38 = tmp13 | tmp37
408
+ mask_mod_output = tmp38
409
+
410
+
411
+ if CHECK_BLOCK_BOUNDARY:
412
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
413
+ # apply mask for partially unmasked blocks
414
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
415
+
416
+ if not PRESCALE_QK:
417
+ post_mod_scores *= RCP_LN2
418
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419
+
420
+ # -- compute scaling constant ---
421
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
422
+ if not ROWS_GUARANTEED_SAFE:
423
+ masked_out_rows = (m_ij == float("-inf"))
424
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
425
+ else:
426
+ m_ij_masked = m_ij
427
+
428
+ alpha = tl.math.exp2(m_i - m_ij_masked)
429
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
430
+
431
+ # NB: l_i update is pulled up here since it's a bit faster
432
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
433
+ # m_ij
434
+ l_i = l_i * alpha + tl.sum(p, 1)
435
+ # # -- scale and update acc --
436
+ acc = acc * alpha[:, None]
437
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
438
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
439
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
440
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
441
+
442
+ # -- update m_i
443
+ m_i = m_ij
444
+
445
+ return acc, l_i, m_i
446
+
447
+ @triton.jit
448
+ def forward_inner(
449
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
450
+ q, K, V,
451
+ desc_k, desc_v, Q_LEN, KV_LEN,
452
+ # accumulated values
453
+ acc, l_i, m_i,
454
+ # Offsets used as inputs to score_mod & mask_mod
455
+ # of size [BLOCK_M, BLOCK_N] or scalar.
456
+ off_z, off_h, offs_m, offs_n,
457
+ # Offsets needed for TMA loads
458
+ kv_start,
459
+ # blocksparse data
460
+ kv_indices, kv_num_blocks,
461
+ # start kv and end kv block
462
+ block_n_start, block_n_end,
463
+ MATMUL_PRECISION,
464
+ # Strides for K and V
465
+ stride_kk, stride_kn, stride_vn, stride_vk,
466
+ IS_FULL_BLOCKS,
467
+ ):
468
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
469
+ PRESCALE_QK : tl.constexpr = False
470
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
471
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
472
+ WRITE_DQ : tl.constexpr = True
473
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
474
+ OUTPUT_MAX : tl.constexpr = False
475
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
476
+ IS_DIVISIBLE : tl.constexpr = False
477
+ SM_SCALE : tl.constexpr = 0.08838834764831843
478
+ GQA_SHARED_HEADS : tl.constexpr = 4
479
+ HAS_FULL_BLOCKS : tl.constexpr = True
480
+ QK_HEAD_DIM : tl.constexpr = 128
481
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
482
+ V_HEAD_DIM : tl.constexpr = 128
483
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
484
+ SAFE_HEAD_DIM : tl.constexpr = True
485
+ USE_TMA : tl.constexpr = False
486
+ BLOCK_M : tl.constexpr = 128
487
+ BLOCK_N : tl.constexpr = 64
488
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
489
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
490
+ INDEX_DTYPE : tl.constexpr = tl.int32
491
+
492
+
493
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
494
+ RCP_LN2: tl.constexpr = 1.44269504
495
+
496
+ if PRESCALE_QK:
497
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
498
+
499
+ kv_offset = 0
500
+
501
+ # loop over k, v and update accumulator until block_n_end
502
+ for start_n in range(block_n_start, block_n_end):
503
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
504
+ if IS_DIVISIBLE:
505
+ acc, l_i, m_i = forward_block_mn(
506
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
507
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
508
+ # accumulated values
509
+ acc, l_i, m_i,
510
+ # Offsets
511
+ off_z, off_h, offs_m, offs_n,
512
+ # Offsets needed for TMA loads
513
+ kv_start,
514
+ kv_offset,
515
+ MATMUL_PRECISION, RCP_LN2,
516
+ # Strides for K and V
517
+ stride_kk, stride_kn, stride_vn, stride_vk,
518
+ IS_FULL_BLOCKS,
519
+ )
520
+ else:
521
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
522
+ # it's on par or slightly faster than only applying to the last block in fwd.
523
+ # However, we choose different strategy for bwd, where we only apply mod & mask
524
+ # to the last block because it's faster a lot.
525
+ acc, l_i, m_i = forward_block_mn(
526
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
527
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
528
+ # accumulated values
529
+ acc, l_i, m_i,
530
+ # Offsets
531
+ off_z, off_h, offs_m, offs_n,
532
+ # Offsets needed for TMA loads
533
+ kv_start,
534
+ kv_offset,
535
+ MATMUL_PRECISION, RCP_LN2,
536
+ # Strides for K and V
537
+ stride_kk, stride_kn, stride_vn, stride_vk,
538
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
539
+ )
540
+
541
+
542
+
543
+ offset = get_offset_for_next_block(
544
+ start_n, kv_indices, kv_num_blocks,
545
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
546
+ )
547
+
548
+ offs_n = offs_n + offset
549
+ kv_offset += offset
550
+
551
+
552
+ return acc, l_i, m_i
SpecForge-ext/cache/compiled_kernels/c4/cc4r2l3x4dfli5iih5dji2abfxoclfozqdaqfbdxtcf6lqfpqwdo.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 524288, 'r0_': 128},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 524288
20
+ r0_numel = 128
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x0 = (xindex % 2048)
29
+ x1 = ((xindex // 2048) % 32)
30
+ x2 = xindex // 65536
31
+ x4 = xindex
32
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
33
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
34
+ r0_index = r0_offset + r0_base
35
+ r0_mask = r0_index < r0_numel
36
+ roffset = r0_offset
37
+ rindex = r0_index
38
+ r0_3 = r0_index
39
+ tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
40
+ tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
41
+ tmp2 = tmp0 * tmp1
42
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
43
+ tmp5 = _tmp4 + tmp3
44
+ _tmp4 = tl.where(r0_mask, tmp5, _tmp4)
45
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
46
+ tmp6 = tmp4.to(tl.float32)
47
+ tmp7 = 0.0
48
+ tmp8 = tmp6 - tmp7
49
+ tl.store(out_ptr1 + (x4), tmp8, None)
SpecForge-ext/cache/compiled_kernels/cm/ccmqky4m65yifqjmfuu7vgvpuhwpa4ybaxffiy3mu2e6yzgecghe.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 1024},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4352}},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xnumel = 544
20
+ xoffset = tl.program_id(0) * XBLOCK
21
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
22
+ xmask = xindex < xnumel
23
+ x0 = xindex
24
+ tmp0 = tl.full([1], 0, tl.int32)
25
+ tl.store(out_ptr0 + (x0), tmp0, xmask)
SpecForge-ext/cache/compiled_kernels/dd/cddrh2oo46t7tins6cvtu23g2titlwclg4aile7eli326p7we42m.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['11_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ks/cksdatp7sjl5kfr5pxvwrbjelhvz35c35rvym5wgbvhrovwd5isa.py
38
+ # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
39
+ # Source node to ATen node mapping:
40
+ # target_head => convert_element_type
41
+ # target_p => div
42
+ # Graph fragment:
43
+ # %arg1_1 : Tensor "bf16[8, s67, 32000][32000*s67, 32000, 1]cuda:4" = PlaceHolder[target=arg1_1]
44
+ # %getitem : Tensor "f32[8, s67, 1][s67, 1, 8*s67]cuda:4" = PlaceHolder[target=getitem]
45
+ # %getitem_1 : Tensor "f32[8, s67, 1][s67, 1, 8*s67]cuda:4" = PlaceHolder[target=getitem_1]
46
+ # %convert_element_type : Tensor "f32[8, s67, 32000][32000*s67, 32000, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg1_1, torch.float32), kwargs = {})
47
+ # %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {})
48
+ # %sub_tensor : Tensor "f32[8, s67, 32000][32000*s67, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {})
49
+ # %exp_default : Tensor "f32[8, s67, 32000][32000*s67, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {})
50
+ # %div : Tensor "f32[8, s67, 32000][32000*s67, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {})
51
+ # return %getitem,%getitem_1,%div
52
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', '''
53
+ import triton
54
+ import triton.language as tl
55
+
56
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
57
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
58
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
59
+ triton_helpers.set_driver_to_gpu()
60
+
61
+ @triton_heuristics.reduction(
62
+ size_hints={'x': 16384, 'r0_': 32768},
63
+ reduction_hint=ReductionHint.INNER,
64
+ filename=__file__,
65
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
67
+ )
68
+ @triton.jit
69
+ def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
70
+ r0_numel = 32000
71
+ rnumel = r0_numel
72
+ RBLOCK: tl.constexpr = R0_BLOCK
73
+ xoffset = tl.program_id(0) * XBLOCK
74
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
75
+ xmask = xindex < xnumel
76
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
77
+ rbase = r0_base
78
+ x0 = xindex
79
+ _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
80
+ _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
81
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
82
+ r0_index = r0_offset + r0_base
83
+ r0_mask = r0_index < r0_numel
84
+ roffset = r0_offset
85
+ rindex = r0_index
86
+ r0_1 = r0_index
87
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
88
+ tmp1 = tmp0.to(tl.float32)
89
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
90
+
91
+ _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
92
+ _tmp3_max, _tmp3_sum, tmp2, False
93
+ )
94
+
95
+ _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max)
96
+ _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum)
97
+
98
+ tmp3, tmp4 = triton_helpers.online_softmax_reduce(
99
+ _tmp3_max, _tmp3_sum, 1, False)
100
+ tmp3 = tmp3[:, None]
101
+ tmp4 = tmp4[:, None]
102
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
103
+ r0_index = r0_offset + r0_base
104
+ r0_mask = r0_index < r0_numel
105
+ roffset = r0_offset
106
+ rindex = r0_index
107
+ r0_1 = r0_index
108
+ tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
109
+ tmp6 = tmp5.to(tl.float32)
110
+ tmp7 = tmp6 - tmp3
111
+ tmp8 = libdevice.exp(tmp7)
112
+ tmp9 = (tmp8 / tmp4)
113
+ tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask)
114
+ ''', device_str='cuda')
115
+
116
+
117
+ async_compile.wait(globals())
118
+ del async_compile
119
+
120
+ class Runner:
121
+ def __init__(self, partitions):
122
+ self.partitions = partitions
123
+
124
+ def recursively_apply_fns(self, fns):
125
+ new_callables = []
126
+ for fn, c in zip(fns, self.partitions):
127
+ new_callables.append(fn(c))
128
+ self.partitions = new_callables
129
+
130
+ def call(self, args):
131
+ arg0_1, arg1_1 = args
132
+ args.clear()
133
+ s67 = arg0_1
134
+ assert_size_stride(arg1_1, (8, s67, 32000), (32000*s67, 32000, 1))
135
+ with torch.cuda._DeviceGuard(4):
136
+ torch.cuda.set_device(4)
137
+ buf2 = empty_strided_cuda((8, s67, 32000), (32000*s67, 32000, 1), torch.float32)
138
+ # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
139
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel = 8*s67
140
+ stream4 = get_raw_stream(4)
141
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg1_1, buf2, triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel, 32000, stream=stream4)
142
+ del arg1_1
143
+ return (buf2, )
144
+
145
+ runner = Runner(partitions=[])
146
+ call = runner.call
147
+ recursively_apply_fns = runner.recursively_apply_fns
148
+
149
+
150
+ def benchmark_compiled_module(times=10, repeat=10):
151
+ from torch._dynamo.testing import rand_strided
152
+ from torch._inductor.utils import print_performance
153
+ arg0_1 = 1896
154
+ arg1_1 = rand_strided((8, 1896, 32000), (60672000, 32000, 1), device='cuda:4', dtype=torch.bfloat16)
155
+ fn = lambda: call([arg0_1, arg1_1])
156
+ return print_performance(fn, times=times, repeat=repeat)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ from torch._inductor.wrapper_benchmark import compiled_module_main
161
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/dl/b1f7dcc79c7c02fa44a9647ad7a02640f8312b36f97c27e92cc10dbab8e47d63.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 28, "triton_cache_hash": "NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ"}
SpecForge-ext/cache/compiled_kernels/dl/cdlmoxz5rmtmnvhkkdtgykahwdzntxp2vrhxdea2s6finrwqdeut.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 32, 'r0_': 16},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ xnumel = 32
20
+ r0_numel = 16
21
+ R0_BLOCK: tl.constexpr = 16
22
+ rnumel = r0_numel
23
+ RBLOCK: tl.constexpr = R0_BLOCK
24
+ xoffset = tl.program_id(0) * XBLOCK
25
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
26
+ xmask = xindex < xnumel
27
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
28
+ r0_offset = 0
29
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
30
+ roffset = r0_offset
31
+ rindex = r0_index
32
+ r0_1 = r0_index
33
+ x0 = xindex
34
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0)
35
+ tmp1 = tl.full([1, 1], 0, tl.int64)
36
+ tmp2 = tmp0 > tmp1
37
+ tmp3 = tl.full([1, 1], 16384, tl.int64)
38
+ tmp4 = tmp0 < tmp3
39
+ tmp5 = tmp2 & tmp4
40
+ tmp6 = tmp5.to(tl.int8)
41
+ tmp7 = tmp6.to(tl.int32)
42
+ tmp8 = r0_1
43
+ tmp9 = tmp8.to(tl.int16)
44
+ tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
45
+ tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
46
+ tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True)
47
+ tmp14 = tmp0 == tmp3
48
+ tmp15 = tmp14.to(tl.int8)
49
+ tmp16 = tmp15.to(tl.int32)
50
+ tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
51
+ tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True)
52
+ tmp20 = tmp7.to(tl.int64)
53
+ tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK])
54
+ tmp23 = tl.where(xmask, tmp21, 0)
55
+ tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64)
56
+ tmp25 = tmp16.to(tl.int64)
57
+ tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
58
+ tmp28 = tl.where(xmask, tmp26, 0)
59
+ tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64)
60
+ tmp30 = tmp24.to(tl.int32)
61
+ tmp31 = tmp29.to(tl.int32)
62
+ tmp32 = tmp13.to(tl.int64)
63
+ tmp33 = tmp32.to(tl.int32)
64
+ tmp34 = tmp8 < tmp30
65
+ tmp35 = tl.full([1, 1], 16, tl.int32)
66
+ tmp36 = tl.where(tmp34, tmp33, tmp35)
67
+ tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32)
68
+ tmp38 = tmp36 + tmp37
69
+ tmp39 = tmp36 < 0
70
+ tmp40 = tl.where(tmp39, tmp38, tmp36)
71
+ tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17")
72
+ tmp42 = tl.full([1, 1], 1, tl.int32)
73
+ tmp43 = tmp19.to(tl.int64)
74
+ tmp44 = tmp43.to(tl.int32)
75
+ tmp45 = tmp8 < tmp31
76
+ tmp46 = tl.where(tmp45, tmp44, tmp35)
77
+ tmp47 = tmp46 + tmp37
78
+ tmp48 = tmp46 < 0
79
+ tmp49 = tl.where(tmp48, tmp47, tmp46)
80
+ tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17")
81
+ tl.store(out_ptr4 + (x0), tmp30, xmask)
82
+ tl.store(out_ptr5 + (x0), tmp31, xmask)
83
+ tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask)
84
+ tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
85
+ tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask)
86
+ tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
SpecForge-ext/cache/compiled_kernels/do/cdoarqsgem4ej5qjlp6zd22rf6fimpoonczzpmfv63um26txbfab.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['10_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xp/cxprnl6wyrkxecwymb5nwdyyiuq4vpew4zlpdy2zpq7whdmm3twe.py
38
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # getitem_1 => unsqueeze
41
+ # position_mask => mul_2
42
+ # target_mask => index
43
+ # target_mask_1 => convert_element_type
44
+ # target_max_token => argmax
45
+ # Graph fragment:
46
+ # %arg1_1 : Tensor "bf16[2, s14, 151936][151936*s14, 151936, 1]cuda:7" = PlaceHolder[target=arg1_1]
47
+ # %argmax : Tensor "i64[2, s14][s14, 1]cuda:7" = PlaceHolder[target=argmax]
48
+ # %arg2_1 : Tensor "b8[151936][1]cuda:7" = PlaceHolder[target=arg2_1]
49
+ # %arg3_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:7" = PlaceHolder[target=arg3_1]
50
+ # %argmax : Tensor "i64[2, s14][s14, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {})
51
+ # %index : Tensor "b8[2, s14][s14, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%argmax]), kwargs = {})
52
+ # %unsqueeze : Tensor "b8[2, s14, 1][s14, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {})
53
+ # %convert_element_type : Tensor "i32[2, s14, 1][s14, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {})
54
+ # %mul_2 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg3_1), kwargs = {})
55
+ # return %argmax,%mul_2
56
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', '''
57
+ import triton
58
+ import triton.language as tl
59
+
60
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
61
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
62
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
63
+ triton_helpers.set_driver_to_gpu()
64
+
65
+ @triton_heuristics.reduction(
66
+ size_hints={'x': 4096, 'r0_': 262144},
67
+ reduction_hint=ReductionHint.INNER,
68
+ filename=__file__,
69
+ triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
70
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
71
+ )
72
+ @triton.jit
73
+ def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
74
+ r0_numel = 151936
75
+ rnumel = r0_numel
76
+ RBLOCK: tl.constexpr = R0_BLOCK
77
+ xoffset = tl.program_id(0) * XBLOCK
78
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
79
+ xmask = xindex < xnumel
80
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
81
+ rbase = r0_base
82
+ x0 = xindex
83
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
84
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
85
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
86
+ r0_index = r0_offset + r0_base
87
+ r0_mask = r0_index < r0_numel
88
+ roffset = r0_offset
89
+ rindex = r0_index
90
+ r0_1 = r0_index
91
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
92
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
93
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
94
+ _tmp2, _tmp2_index, tmp1, rindex
95
+ )
96
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
97
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
98
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
99
+ tmp2 = tmp2_idx[:, None]
100
+ tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
101
+ tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32)
102
+ tmp4 = tmp2 + tmp3
103
+ tmp5 = tmp2 < 0
104
+ tmp6 = tl.where(tmp5, tmp4, tmp2)
105
+ tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936")
106
+ tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1)
107
+ tmp9 = tmp8.to(tl.int32)
108
+ tmp10 = tmp9.to(tl.int64)
109
+ tmp12 = tmp10 * tmp11
110
+ tl.debug_barrier()
111
+ tl.store(in_out_ptr0 + (x0), tmp12, xmask)
112
+ ''', device_str='cuda')
113
+
114
+
115
+ async_compile.wait(globals())
116
+ del async_compile
117
+
118
+ class Runner:
119
+ def __init__(self, partitions):
120
+ self.partitions = partitions
121
+
122
+ def recursively_apply_fns(self, fns):
123
+ new_callables = []
124
+ for fn, c in zip(fns, self.partitions):
125
+ new_callables.append(fn(c))
126
+ self.partitions = new_callables
127
+
128
+ def call(self, args):
129
+ arg0_1, arg1_1, arg2_1, arg3_1 = args
130
+ args.clear()
131
+ s24 = arg0_1
132
+ arg1_1_size = arg1_1.size()
133
+ s14 = arg1_1_size[1]
134
+ assert_size_stride(arg1_1, (2, s14, 151936), (151936*s14, 151936, 1))
135
+ assert_size_stride(arg2_1, (151936, ), (1, ))
136
+ assert_size_stride(arg3_1, (2, s14, 1), (s14, 1, 1))
137
+ with torch.cuda._DeviceGuard(7):
138
+ torch.cuda.set_device(7)
139
+ buf0 = empty_strided_cuda((2, s14), (s14, 1), torch.int64)
140
+ buf1 = reinterpret_tensor(buf0, (2, s14, 1), (s14, 1, 1), 0); del buf0 # reuse
141
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
142
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel = 2*s14
143
+ stream7 = get_raw_stream(7)
144
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg1_1, arg2_1, arg3_1, triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel, 151936, stream=stream7)
145
+ del arg1_1
146
+ del arg2_1
147
+ del arg3_1
148
+ return (buf1, )
149
+
150
+ runner = Runner(partitions=[])
151
+ call = runner.call
152
+ recursively_apply_fns = runner.recursively_apply_fns
153
+
154
+
155
+ def benchmark_compiled_module(times=10, repeat=10):
156
+ from torch._dynamo.testing import rand_strided
157
+ from torch._inductor.utils import print_performance
158
+ arg0_1 = 1904
159
+ arg1_1 = rand_strided((2, 1904, 151936), (289286144, 151936, 1), device='cuda:7', dtype=torch.bfloat16)
160
+ arg2_1 = rand_strided((151936, ), (1, ), device='cuda:7', dtype=torch.bool)
161
+ arg3_1 = rand_strided((2, 1904, 1), (1904, 1, 1), device='cuda:7', dtype=torch.int64)
162
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1])
163
+ return print_performance(fn, times=times, repeat=repeat)
164
+
165
+
166
+ if __name__ == "__main__":
167
+ from torch._inductor.wrapper_benchmark import compiled_module_main
168
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/dq/bbb4d7862e75b16b3f47ca1a7d19d9cb4b2d5337c27f7396cb01891263c9b13a.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 8, "R0_BLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "b70837e3723f218c7368cc2b49566dcd2bec3baf4c88b5e174a3f0822a6c86c0", "found_by_coordesc": false, "time_taken_ms": 142, "triton_cache_hash": "BZ2FPB5QIE7EHR6P7EPVPHR4HKS3YX3QQPIWQIT2R3EOJOAVWCGA"}
SpecForge-ext/cache/compiled_kernels/dq/cdq6jyounnaz2w4x6s5oljefpge3fzx66pi3x25iwcuc6vazkfx6.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 64, 'r0_': 16},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 16
20
+ R0_BLOCK: tl.constexpr = 16
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_2 = r0_index
32
+ x0 = (xindex % ks0)
33
+ x1 = xindex // ks0
34
+ x3 = xindex
35
+ tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0)
36
+ tmp1 = r0_2
37
+ tmp2 = tmp1.to(tl.int16)
38
+ tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
39
+ tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
40
+ tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True)
41
+ tmp7 = tmp0.to(tl.int64)
42
+ tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
43
+ tmp10 = tl.where(xmask, tmp8, 0)
44
+ tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64)
45
+ tmp12 = tmp6.to(tl.int64)
46
+ tmp13 = tmp12.to(tl.int32)
47
+ tmp14 = tmp11.to(tl.int32)
48
+ tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask)
49
+ tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask)
SpecForge-ext/cache/compiled_kernels/dq/cdqxxevdyssoyut2euw55y27cahqqcgmvyuhdihb4tmner7cfc7f.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 524288, 'r0_': 128},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 524288
20
+ r0_numel = 128
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x0 = (xindex % 2048)
29
+ x1 = ((xindex // 2048) % 32)
30
+ x2 = xindex // 65536
31
+ x4 = xindex
32
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
33
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
34
+ r0_index = r0_offset + r0_base
35
+ r0_mask = r0_index < r0_numel
36
+ roffset = r0_offset
37
+ rindex = r0_index
38
+ r0_3 = r0_index
39
+ tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
40
+ tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
41
+ tmp2 = tmp0 * tmp1
42
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
43
+ tmp5 = _tmp4 + tmp3
44
+ _tmp4 = tl.where(r0_mask, tmp5, _tmp4)
45
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
46
+ tmp6 = tmp4.to(tl.float32)
47
+ tmp7 = 0.0
48
+ tmp8 = tmp6 - tmp7
49
+ tl.store(out_ptr1 + (x4), tmp8, None)
SpecForge-ext/cache/compiled_kernels/dq/e6aa9461d93df8973681493d15479cff1a0d8302c7a7de253f84ade82cf09c3e.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"}
SpecForge-ext/cache/compiled_kernels/dt/cdthlbsdpcqgxus7ldvwk23vvgojrmkgt7yidbhj27c2esjsap6w.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['0_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ky/ckymincticcpi6whoxumnurwhspwbrhpbcg34533u5yjkbf7m3oy.py
38
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # getitem_1 => unsqueeze
41
+ # position_mask => mul
42
+ # target_mask => index
43
+ # target_mask_1 => convert_element_type
44
+ # target_max_token => argmax
45
+ # Graph fragment:
46
+ # %arg0_1 : Tensor "bf16[2, 2048, 151936][311164928, 151936, 1]cuda:3" = PlaceHolder[target=arg0_1]
47
+ # %argmax : Tensor "i64[2, 2048][2048, 1]cuda:3" = PlaceHolder[target=argmax]
48
+ # %arg1_1 : Tensor "b8[151936][1]cuda:3" = PlaceHolder[target=arg1_1]
49
+ # %arg2_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:3" = PlaceHolder[target=arg2_1]
50
+ # %argmax : Tensor "i64[2, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {})
51
+ # %index : Tensor "b8[2, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%argmax]), kwargs = {})
52
+ # %unsqueeze : Tensor "b8[2, 2048, 1][2048, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {})
53
+ # %convert_element_type : Tensor "i32[2, 2048, 1][2048, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {})
54
+ # %mul : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg2_1), kwargs = {})
55
+ # return %argmax,%mul
56
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', '''
57
+ import triton
58
+ import triton.language as tl
59
+
60
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
61
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
62
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
63
+ triton_helpers.set_driver_to_gpu()
64
+
65
+ @triton_heuristics.reduction(
66
+ size_hints={'x': 4096, 'r0_': 262144},
67
+ reduction_hint=ReductionHint.INNER,
68
+ filename=__file__,
69
+ triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
70
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
71
+ )
72
+ @triton.jit
73
+ def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
74
+ xnumel = 4096
75
+ r0_numel = 151936
76
+ rnumel = r0_numel
77
+ RBLOCK: tl.constexpr = R0_BLOCK
78
+ xoffset = tl.program_id(0) * XBLOCK
79
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
80
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
81
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
82
+ rbase = r0_base
83
+ x0 = xindex
84
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
85
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
86
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
87
+ r0_index = r0_offset + r0_base
88
+ r0_mask = r0_index < r0_numel
89
+ roffset = r0_offset
90
+ rindex = r0_index
91
+ r0_1 = r0_index
92
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
93
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
94
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
95
+ _tmp2, _tmp2_index, tmp1, rindex
96
+ )
97
+ _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
98
+ _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
99
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
100
+ tmp2 = tmp2_idx[:, None]
101
+ tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
102
+ tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32)
103
+ tmp4 = tmp2 + tmp3
104
+ tmp5 = tmp2 < 0
105
+ tmp6 = tl.where(tmp5, tmp4, tmp2)
106
+ tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936")
107
+ tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1)
108
+ tmp9 = tmp8.to(tl.int32)
109
+ tmp10 = tmp9.to(tl.int64)
110
+ tmp12 = tmp10 * tmp11
111
+ tl.debug_barrier()
112
+ tl.store(in_out_ptr0 + (x0), tmp12, None)
113
+ ''', device_str='cuda')
114
+
115
+
116
+ async_compile.wait(globals())
117
+ del async_compile
118
+
119
+ class Runner:
120
+ def __init__(self, partitions):
121
+ self.partitions = partitions
122
+
123
+ def recursively_apply_fns(self, fns):
124
+ new_callables = []
125
+ for fn, c in zip(fns, self.partitions):
126
+ new_callables.append(fn(c))
127
+ self.partitions = new_callables
128
+
129
+ def call(self, args):
130
+ arg0_1, arg1_1, arg2_1 = args
131
+ args.clear()
132
+ assert_size_stride(arg0_1, (2, 2048, 151936), (311164928, 151936, 1))
133
+ assert_size_stride(arg1_1, (151936, ), (1, ))
134
+ assert_size_stride(arg2_1, (2, 2048, 1), (2048, 1, 1))
135
+ with torch.cuda._DeviceGuard(3):
136
+ torch.cuda.set_device(3)
137
+ buf0 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64)
138
+ buf1 = reinterpret_tensor(buf0, (2, 2048, 1), (2048, 1, 1), 0); del buf0 # reuse
139
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
140
+ stream3 = get_raw_stream(3)
141
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg0_1, arg1_1, arg2_1, 4096, 151936, stream=stream3)
142
+ del arg0_1
143
+ del arg1_1
144
+ del arg2_1
145
+ return (buf1, )
146
+
147
+ runner = Runner(partitions=[])
148
+ call = runner.call
149
+ recursively_apply_fns = runner.recursively_apply_fns
150
+
151
+
152
+ def benchmark_compiled_module(times=10, repeat=10):
153
+ from torch._dynamo.testing import rand_strided
154
+ from torch._inductor.utils import print_performance
155
+ arg0_1 = rand_strided((2, 2048, 151936), (311164928, 151936, 1), device='cuda:3', dtype=torch.bfloat16)
156
+ arg1_1 = rand_strided((151936, ), (1, ), device='cuda:3', dtype=torch.bool)
157
+ arg2_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:3', dtype=torch.int64)
158
+ fn = lambda: call([arg0_1, arg1_1, arg2_1])
159
+ return print_performance(fn, times=times, repeat=repeat)
160
+
161
+
162
+ if __name__ == "__main__":
163
+ from torch._inductor.wrapper_benchmark import compiled_module_main
164
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/dt/cdtjh6gxoepiahz2caz7vmm66wc5rf2ib5iyvtxe3w7pr44tvvpt.py ADDED
@@ -0,0 +1,1051 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['6_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/kz/ckzf3m2manw23rbqxotolgimwqgjhy7lsthrid5s266iqj226dep.py
38
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
39
+ # Source node to ATen node mapping:
40
+ # Graph fragment:
41
+ # %getitem : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem]
42
+ # %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:4" = PlaceHolder[target=tangents_1]
43
+ # %buf0 : Tensor "bf16[2, 32, 2048][65536, 2048, 1]cuda:4" = PlaceHolder[target=buf0]
44
+ # %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:4, pin_memory: False})
45
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {})
46
+ # return %buf0,%buf1
47
+ triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', '''
48
+ import triton
49
+ import triton.language as tl
50
+
51
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
52
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
53
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
54
+ triton_helpers.set_driver_to_gpu()
55
+
56
+ @triton_heuristics.reduction(
57
+ size_hints={'x': 131072, 'r0_': 128},
58
+ reduction_hint=ReductionHint.DEFAULT,
59
+ filename=__file__,
60
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
61
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}}
62
+ )
63
+ @triton.jit
64
+ def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
65
+ xnumel = 131072
66
+ r0_numel = 128
67
+ rnumel = r0_numel
68
+ RBLOCK: tl.constexpr = R0_BLOCK
69
+ xoffset = tl.program_id(0) * XBLOCK
70
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
71
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
72
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
73
+ rbase = r0_base
74
+ x0 = (xindex % 2048)
75
+ x1 = ((xindex // 2048) % 32)
76
+ x2 = xindex // 65536
77
+ x4 = xindex
78
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
79
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
80
+ r0_index = r0_offset + r0_base
81
+ r0_mask = r0_index < r0_numel
82
+ roffset = r0_offset
83
+ rindex = r0_index
84
+ r0_3 = r0_index
85
+ tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
86
+ tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
87
+ tmp2 = tmp0 * tmp1
88
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
89
+ tmp5 = _tmp4 + tmp3
90
+ _tmp4 = tl.where(r0_mask, tmp5, _tmp4)
91
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
92
+ tmp6 = tmp4.to(tl.float32)
93
+ tmp7 = 0.0
94
+ tmp8 = tmp6 - tmp7
95
+ tl.store(out_ptr1 + (x4), tmp8, None)
96
+ ''', device_str='cuda')
97
+
98
+
99
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/cl/cclgog3gyib2chh5xgwqlrms5pk2giqv3sr2wpqipwovq6esktbk.py
100
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
101
+ # Source node to ATen node mapping:
102
+ # Graph fragment:
103
+ # %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_1]
104
+ # %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:4" = PlaceHolder[target=primals_2]
105
+ # %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:4" = PlaceHolder[target=primals_3]
106
+ # %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:4" = PlaceHolder[target=getitem_1]
107
+ # %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:4" = PlaceHolder[target=buf1]
108
+ # %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:4" = PlaceHolder[target=tangents_1]
109
+ # %getitem_3 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem_3]
110
+ # %getitem_5 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:4" = PlaceHolder[target=getitem_5]
111
+ # %primals_5 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_5]
112
+ # %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_4]
113
+ # %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_9]
114
+ # %primals_10 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_10]
115
+ # %primals_7 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_7]
116
+ # %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_8]
117
+ # %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_11]
118
+ # %primals_12 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_12]
119
+ # %primals_6 : Tensor "i64[2][1]cuda:4" = PlaceHolder[target=primals_6]
120
+ # %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:4, pin_memory: False})
121
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {})
122
+ # return %getitem_4
123
+ triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', '''
124
+ import triton
125
+ import triton.language as tl
126
+
127
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
128
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
129
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
130
+
131
+ @triton_heuristics.template(
132
+
133
+ num_stages=3,
134
+ num_warps=8,
135
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
136
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
137
+
138
+ )
139
+ @triton.jit
140
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0):
141
+ PRESCALE_QK : tl.constexpr = False
142
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
143
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
144
+ WRITE_DQ : tl.constexpr = True
145
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
146
+ OUTPUT_MAX : tl.constexpr = False
147
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
148
+ IS_DIVISIBLE : tl.constexpr = True
149
+ SM_SCALE : tl.constexpr = 0.08838834764831843
150
+ GQA_SHARED_HEADS : tl.constexpr = 4
151
+ HAS_FULL_BLOCKS : tl.constexpr = True
152
+ QK_HEAD_DIM : tl.constexpr = 128
153
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
154
+ V_HEAD_DIM : tl.constexpr = 128
155
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
156
+ SAFE_HEAD_DIM : tl.constexpr = True
157
+ BLOCK_M1 : tl.constexpr = 64
158
+ BLOCK_N1 : tl.constexpr = 128
159
+ BLOCK_M2 : tl.constexpr = 128
160
+ BLOCK_N2 : tl.constexpr = 64
161
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
162
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
163
+ INDEX_DTYPE : tl.constexpr = tl.int32
164
+ Q = arg_Q
165
+ K = arg_K
166
+ V = arg_V
167
+ LSE = arg_LSE
168
+ DELTA = arg_DELTA
169
+ DO = arg_DO
170
+ DQ = arg_DQ
171
+ DV = arg_DV
172
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
173
+ KV_IDX = arg_KV_IDX
174
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
175
+ Q_IDX = arg_Q_IDX
176
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
177
+ FULL_KV_IDX = arg_FULL_KV_IDX
178
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
179
+ FULL_Q_IDX = arg_FULL_Q_IDX
180
+
181
+ # Sub notation for this kernel:
182
+ #
183
+ # Q: Query, K: Key, V: Value
184
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
185
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
186
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
187
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
188
+ # inductor codegen
189
+ # M: Number of queries, N: Number of keys/values
190
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
191
+ # V_HEAD_DIM: The dimension of the value embeddings
192
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
193
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
194
+ # (Modifiable) Performance tuning options
195
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
196
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
197
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
198
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
199
+ #
200
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
201
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
202
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
203
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
204
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
205
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
206
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
207
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
208
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
209
+
210
+ # The below are kernel options that can be applied for certain score_mods,
211
+ # or involve a numerics vs. perf tradeoff
212
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
213
+ # about 20% more numerical error, but slightly faster.
214
+
215
+ # Define strides of inputs
216
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
217
+ stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1
218
+ stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1
219
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
220
+
221
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
222
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1
223
+
224
+ ZQ = 2
225
+ HQ = 32
226
+ HKV = 8
227
+ Q_LEN = 2048
228
+ ZKV = 2
229
+ KV_LEN = 2048
230
+
231
+ MATMUL_PRECISION = Q.dtype.element_ty
232
+
233
+ pid = tl.program_id(0).to(INDEX_DTYPE)
234
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
235
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
236
+
237
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
238
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
239
+ off_zkv = off_zq % ZKV # kv batch idx
240
+
241
+ SPARSE_Z = 2
242
+ SPARSE_HQ = 1
243
+
244
+ sparse_idx_z = off_zq % SPARSE_Z
245
+
246
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
247
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
248
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
249
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
250
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
251
+
252
+ # offset K, V, DV pointers for batch/kv-head
253
+ K += k_adj
254
+ V += v_adj
255
+ DV += dv_adj
256
+
257
+ RCP_LN2 = 1.44269504
258
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
259
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
260
+
261
+ if pid >= NUM_KV_BLOCKS:
262
+ off_pid = pid - NUM_KV_BLOCKS
263
+ # THIS BLOCK DOES DQ
264
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
265
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
266
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
267
+ start_m2_block = off_pid % NUM_Q_BLOCKS
268
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
269
+ stride_kv_num_blks_h = 16
270
+ stride_kv_idx_h = 256
271
+ stride_kv_idx_m = 16
272
+
273
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
274
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
275
+
276
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
277
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
278
+
279
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
280
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
281
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
282
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
283
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
284
+
285
+ Q2 = Q + q_adj2
286
+ DO2 = DO + do_adj2
287
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
288
+ # if Q is broadcasted)
289
+ DQ2 = DQ + dq_adj2
290
+ LSE2 = LSE + off_chz2
291
+ DELTA2 = DELTA + off_chz2
292
+
293
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
294
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
295
+
296
+ start_m2 = start_m2_block * BLOCK_M2
297
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
298
+
299
+ # load Q and do: they stay in SRAM throughout the inner loop.
300
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
301
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
302
+
303
+ if PRESCALE_QK:
304
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
305
+
306
+ if IS_DIVISIBLE:
307
+ Di = tl.load(DELTA2 + offs_m2)
308
+ lse = tl.load(LSE2 + offs_m2)
309
+ else:
310
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
311
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
312
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
313
+ lse = lse[:, None]
314
+
315
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
316
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
317
+ kv_indices = KV_IDX + sparse_kv_idx_offset
318
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
319
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
320
+
321
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
322
+ dq = bwd_dq_inner(
323
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
324
+ K, V,
325
+ dq, q, do, Di, lse,
326
+ off_zq, off_hq2, offs_m2, offs_n2,
327
+ stride_kn, stride_kd, stride_vn, stride_vd,
328
+ kv_indices, sparse_kv_num_blocks,
329
+ MATMUL_PRECISION,
330
+ IS_FULL_BLOCKS=False,
331
+ )
332
+
333
+ if HAS_FULL_BLOCKS:
334
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
335
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
336
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
337
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
338
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
339
+
340
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
341
+ dq = bwd_dq_inner(
342
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
343
+ K, V,
344
+ dq, q, do, Di, lse,
345
+ off_zq, off_hq2, offs_m2, offs_n2,
346
+ stride_kn, stride_kd, stride_vn, stride_vd,
347
+ kv_indices, sparse_kv_num_blocks,
348
+ MATMUL_PRECISION,
349
+ IS_FULL_BLOCKS=True,
350
+ )
351
+
352
+ # Write back dQ.
353
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
354
+ dq *= SM_SCALE
355
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
356
+ tl.store(dq_ptrs, dq)
357
+ else:
358
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
359
+ else:
360
+ # THIS BLOCK DOES DK & DV
361
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
362
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
363
+
364
+ pid_mask = pid // SPARSE_KV_MULTIPLE
365
+
366
+ stride_q_num_blks_h = 16
367
+ stride_q_idx_h = 256
368
+ stride_q_idx_n = 16
369
+
370
+
371
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
372
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
373
+
374
+ start_n1 = pid * BLOCK_N1
375
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
376
+
377
+ # load K and V: they stay in SRAM throughout the inner loop.
378
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
379
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
380
+
381
+ if PRESCALE_QK:
382
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
383
+
384
+ for off_g in range(0, GQA_SHARED_HEADS):
385
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
386
+
387
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
388
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
389
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
390
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
391
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
392
+
393
+ Q1 = Q + q_adj1
394
+ DO1 = DO + do_adj1
395
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
396
+ # if Q is broadcasted)
397
+ LSE1 = LSE + off_chz1
398
+ DELTA1 = DELTA + off_chz1
399
+
400
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
401
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
402
+
403
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
404
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
405
+
406
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
407
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
408
+ q_indices = Q_IDX + sparse_q_idx_offset
409
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
410
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
411
+
412
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
413
+ dk, dv = bwd_dkdv_inner(
414
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
415
+ Q1, DO1, DELTA1, LSE1,
416
+ dk, dv, k, v,
417
+ off_zq, off_hq1, offs_n1, offs_m1,
418
+ stride_qm, stride_qd, stride_dom, stride_dod,
419
+ q_indices, sparse_q_num_blocks,
420
+ MATMUL_PRECISION,
421
+ IS_FULL_BLOCKS=False,
422
+ )
423
+
424
+
425
+ if HAS_FULL_BLOCKS:
426
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
427
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
428
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
429
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
430
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
431
+
432
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
433
+ dk, dv = bwd_dkdv_inner(
434
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
435
+ Q1, DO1, DELTA1, LSE1,
436
+ dk, dv, k, v,
437
+ off_zq, off_hq1, offs_n1, offs_m1,
438
+ stride_qm, stride_qd, stride_dom, stride_dod,
439
+ q_indices, sparse_q_num_blocks,
440
+ MATMUL_PRECISION,
441
+ IS_FULL_BLOCKS=True,
442
+ )
443
+
444
+ # Write back dV and dK.
445
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
446
+
447
+ index_n = offs_n1[:, None]
448
+ index_k = offs_k[None, :]
449
+ index_v = offs_v[None, :]
450
+
451
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
452
+ tl.store(dv_ptrs, dv)
453
+ else:
454
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
455
+
456
+ dk *= SM_SCALE
457
+
458
+ if SAFE_HEAD_DIM:
459
+ mask = index_n < KV_LEN
460
+ else:
461
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
462
+
463
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
464
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
465
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
466
+ xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq
467
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
468
+
469
+ @triton.jit
470
+ def bwd_dq_inner(
471
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
472
+ K, V, # pointers
473
+ dq, q, do, Di, lse,
474
+ off_z, off_hq, offs_m2, offs_n2,
475
+ stride_kn, stride_kd, stride_vn, stride_vd,
476
+ kv_indices, sparse_kv_num_blocks,
477
+ MATMUL_PRECISION,
478
+ IS_FULL_BLOCKS,
479
+ ):
480
+ PRESCALE_QK : tl.constexpr = False
481
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
482
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
483
+ WRITE_DQ : tl.constexpr = True
484
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
485
+ OUTPUT_MAX : tl.constexpr = False
486
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
487
+ IS_DIVISIBLE : tl.constexpr = True
488
+ SM_SCALE : tl.constexpr = 0.08838834764831843
489
+ GQA_SHARED_HEADS : tl.constexpr = 4
490
+ HAS_FULL_BLOCKS : tl.constexpr = True
491
+ QK_HEAD_DIM : tl.constexpr = 128
492
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
493
+ V_HEAD_DIM : tl.constexpr = 128
494
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
495
+ SAFE_HEAD_DIM : tl.constexpr = True
496
+ BLOCK_M1 : tl.constexpr = 64
497
+ BLOCK_N1 : tl.constexpr = 128
498
+ BLOCK_M2 : tl.constexpr = 128
499
+ BLOCK_N2 : tl.constexpr = 64
500
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
501
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
502
+ INDEX_DTYPE : tl.constexpr = tl.int32
503
+
504
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
505
+ RCP_LN2: tl.constexpr = 1.44269504
506
+ Q_LEN = 2048
507
+ KV_LEN = 2048
508
+
509
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
510
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
511
+
512
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
513
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
514
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
515
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
516
+
517
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
518
+
519
+ for start_n in range(0, hi):
520
+ dq = bwd_dq_block_mn(
521
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
522
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
523
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
524
+ stride_kn, stride_kd, stride_vn, stride_vd,
525
+ kv_indices, sparse_kv_num_blocks,
526
+ MATMUL_PRECISION, RCP_LN2,
527
+ IS_FULL_BLOCKS,
528
+ )
529
+
530
+ # Increment pointers.
531
+ offset = get_offset_for_next_block(
532
+ start_n, kv_indices, sparse_kv_num_blocks,
533
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
534
+ )
535
+
536
+ kT_ptrs += offset * stride_kn
537
+ vT_ptrs += offset * stride_vn
538
+
539
+ offs_n2 += offset
540
+
541
+ return dq
542
+
543
+
544
+ @triton.jit
545
+ def bwd_dq_block_mn(
546
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
547
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
548
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
549
+ stride_kn, stride_kd, stride_vn, stride_vd,
550
+ kv_indices, sparse_kv_num_blocks,
551
+ MATMUL_PRECISION, RCP_LN2,
552
+ IS_FULL_BLOCKS,
553
+ ):
554
+ PRESCALE_QK : tl.constexpr = False
555
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
556
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
557
+ WRITE_DQ : tl.constexpr = True
558
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
559
+ OUTPUT_MAX : tl.constexpr = False
560
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
561
+ IS_DIVISIBLE : tl.constexpr = True
562
+ SM_SCALE : tl.constexpr = 0.08838834764831843
563
+ GQA_SHARED_HEADS : tl.constexpr = 4
564
+ HAS_FULL_BLOCKS : tl.constexpr = True
565
+ QK_HEAD_DIM : tl.constexpr = 128
566
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
567
+ V_HEAD_DIM : tl.constexpr = 128
568
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
569
+ SAFE_HEAD_DIM : tl.constexpr = True
570
+ BLOCK_M1 : tl.constexpr = 64
571
+ BLOCK_N1 : tl.constexpr = 128
572
+ BLOCK_M2 : tl.constexpr = 128
573
+ BLOCK_N2 : tl.constexpr = 64
574
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
575
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
576
+ INDEX_DTYPE : tl.constexpr = tl.int32
577
+
578
+
579
+ # NB reversed order to since K is transposed
580
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
581
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
582
+ if not PRESCALE_QK:
583
+ qk *= SM_SCALE
584
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
585
+ pre_mod_scores = qk
586
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
587
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
588
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
589
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
590
+
591
+ tmp0 = (qk)
592
+ post_mod_scores = tmp0
593
+
594
+
595
+
596
+
597
+ if not IS_DIVISIBLE:
598
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
599
+
600
+ if not IS_FULL_BLOCKS:
601
+ tmp1 = tl.full([1], False, tl.int1)
602
+ tmp2 = (m)
603
+ tmp3 = (n)
604
+ tmp4 = tmp2 >= tmp3
605
+ tmp5 = tmp3.to(tl.int64)
606
+ tmp6 = (off_z)
607
+ tmp7 = tl.load(in_ptr16 + tmp6)
608
+ tmp8 = tmp5 < tmp7
609
+ tmp9 = tmp2.to(tl.int64)
610
+ tmp10 = tmp9 < tmp7
611
+ tmp11 = tmp8 & tmp10
612
+ tmp12 = tmp4 & tmp11
613
+ tmp13 = tmp1 | tmp12
614
+ tmp14 = tl.full([1], 2048, tl.int32)
615
+ tmp15 = tmp3 >= tmp14
616
+ tmp16 = (tmp3 % tmp14)
617
+ tmp17 = tl.full([1], 0, tl.int32)
618
+ tmp18 = tmp16 != tmp17
619
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
620
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
621
+ tmp21 = tmp19 != tmp20
622
+ tmp22 = tmp18 & tmp21
623
+ tmp23 = tmp16 + tmp14
624
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
625
+ tmp25 = tmp24.to(tl.int64)
626
+ tmp26 = tmp25 < tmp7
627
+ tmp27 = tmp15 & tmp26
628
+ tmp28 = tmp3 - tmp2
629
+ tmp29 = (tmp28 % tmp14)
630
+ tmp30 = tmp29 != tmp17
631
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
632
+ tmp32 = tmp31 != tmp20
633
+ tmp33 = tmp30 & tmp32
634
+ tmp34 = tmp29 + tmp14
635
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
636
+ tmp36 = tmp35 == tmp17
637
+ tmp37 = tmp27 & tmp36
638
+ tmp38 = tmp13 | tmp37
639
+ mask_mod_output = tmp38
640
+
641
+
642
+ # apply mask for partial masked block
643
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
644
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
645
+ if not PRESCALE_QK:
646
+ post_mod_scores *= RCP_LN2
647
+ p = tl.math.exp2(post_mod_scores - lse)
648
+ # Compute dP and dS.
649
+ # NB reversed order to since V is transposed
650
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
651
+
652
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
653
+ ds = p * (dp - Di[:, None])
654
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
655
+ tmp39 = (ds)
656
+ grad_scores = tmp39
657
+
658
+
659
+ if not IS_DIVISIBLE:
660
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
661
+
662
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
663
+ if WRITE_DQ:
664
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
665
+
666
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
667
+ ds = grad_scores
668
+
669
+ if not IS_FULL_BLOCKS:
670
+ # (grads) apply mask for partially unmasked block
671
+ ds = tl.where(mask_mod_output, ds, 0.0)
672
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
673
+ ds = ds.to(MATMUL_PRECISION)
674
+ # Compute dQ.
675
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
676
+
677
+ return dq
678
+
679
+
680
+ @triton.jit
681
+ def bwd_dkdv_inner(
682
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
683
+ Q, DO, DELTA, LSE, # pointers
684
+ dk, dv, k, v,
685
+ off_z, off_hq, offs_n1, offs_m1,
686
+ stride_qm, stride_qd, stride_dom, stride_dod,
687
+ q_indices, sparse_q_num_blocks,
688
+ MATMUL_PRECISION,
689
+ IS_FULL_BLOCKS,
690
+ ):
691
+ PRESCALE_QK : tl.constexpr = False
692
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
693
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
694
+ WRITE_DQ : tl.constexpr = True
695
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
696
+ OUTPUT_MAX : tl.constexpr = False
697
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
698
+ IS_DIVISIBLE : tl.constexpr = True
699
+ SM_SCALE : tl.constexpr = 0.08838834764831843
700
+ GQA_SHARED_HEADS : tl.constexpr = 4
701
+ HAS_FULL_BLOCKS : tl.constexpr = True
702
+ QK_HEAD_DIM : tl.constexpr = 128
703
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
704
+ V_HEAD_DIM : tl.constexpr = 128
705
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
706
+ SAFE_HEAD_DIM : tl.constexpr = True
707
+ BLOCK_M1 : tl.constexpr = 64
708
+ BLOCK_N1 : tl.constexpr = 128
709
+ BLOCK_M2 : tl.constexpr = 128
710
+ BLOCK_N2 : tl.constexpr = 64
711
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
712
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
713
+ INDEX_DTYPE : tl.constexpr = tl.int32
714
+
715
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
716
+ RCP_LN2: tl.constexpr = 1.44269504
717
+ Q_LEN = 2048
718
+ KV_LEN = 2048
719
+
720
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
721
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
722
+
723
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
724
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
725
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
726
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
727
+
728
+ # The minimum is needed to handle the case where we run with a super large
729
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
730
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
731
+
732
+ for start_m in range(0, hi):
733
+ dk, dv = bwd_dkdv_block_mn(
734
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
735
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
736
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
737
+ stride_qm, stride_qd, stride_dom, stride_dod,
738
+ q_indices, sparse_q_num_blocks,
739
+ MATMUL_PRECISION, RCP_LN2,
740
+ IS_FULL_BLOCKS,
741
+ )
742
+ # Increment pointers.
743
+ offset = get_offset_for_next_block(
744
+ start_m, q_indices, sparse_q_num_blocks,
745
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
746
+ )
747
+
748
+ qT_ptrs += offset * stride_qm
749
+ do_ptrs += offset * stride_dom
750
+ offs_m1 += offset
751
+
752
+ return dk, dv
753
+
754
+
755
+ @triton.jit
756
+ def bwd_dkdv_block_mn(
757
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
758
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
759
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
760
+ stride_qm, stride_qd, stride_dom, stride_dod,
761
+ q_indices, sparse_q_num_blocks,
762
+ MATMUL_PRECISION, RCP_LN2,
763
+ IS_FULL_BLOCKS,
764
+ ):
765
+ PRESCALE_QK : tl.constexpr = False
766
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
767
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
768
+ WRITE_DQ : tl.constexpr = True
769
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
770
+ OUTPUT_MAX : tl.constexpr = False
771
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
772
+ IS_DIVISIBLE : tl.constexpr = True
773
+ SM_SCALE : tl.constexpr = 0.08838834764831843
774
+ GQA_SHARED_HEADS : tl.constexpr = 4
775
+ HAS_FULL_BLOCKS : tl.constexpr = True
776
+ QK_HEAD_DIM : tl.constexpr = 128
777
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
778
+ V_HEAD_DIM : tl.constexpr = 128
779
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
780
+ SAFE_HEAD_DIM : tl.constexpr = True
781
+ BLOCK_M1 : tl.constexpr = 64
782
+ BLOCK_N1 : tl.constexpr = 128
783
+ BLOCK_M2 : tl.constexpr = 128
784
+ BLOCK_N2 : tl.constexpr = 64
785
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
786
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
787
+ INDEX_DTYPE : tl.constexpr = tl.int32
788
+
789
+
790
+ # NB reversed order since Q is transposed
791
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
792
+ # Load LSE before computing qk to reduce pipeline stall.
793
+ if IS_DIVISIBLE:
794
+ lse = tl.load(LSE + offs_m1)
795
+ else:
796
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
797
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
798
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
799
+ if not PRESCALE_QK:
800
+ qkT *= SM_SCALE
801
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
802
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
803
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
804
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
805
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
806
+
807
+ pre_mod_scores = qkT
808
+ tmp40 = (qkT)
809
+ post_mod_scores = tmp40
810
+
811
+
812
+
813
+ if not IS_DIVISIBLE:
814
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
815
+
816
+ if not IS_FULL_BLOCKS:
817
+ tmp41 = tl.full([1], False, tl.int1)
818
+ tmp42 = (m)
819
+ tmp43 = (n)
820
+ tmp44 = tmp42 >= tmp43
821
+ tmp45 = tmp43.to(tl.int64)
822
+ tmp46 = (off_z)
823
+ tmp47 = tl.load(in_ptr16 + tmp46)
824
+ tmp48 = tmp45 < tmp47
825
+ tmp49 = tmp42.to(tl.int64)
826
+ tmp50 = tmp49 < tmp47
827
+ tmp51 = tmp48 & tmp50
828
+ tmp52 = tmp44 & tmp51
829
+ tmp53 = tmp41 | tmp52
830
+ tmp54 = tl.full([1], 2048, tl.int32)
831
+ tmp55 = tmp43 >= tmp54
832
+ tmp56 = (tmp43 % tmp54)
833
+ tmp57 = tl.full([1], 0, tl.int32)
834
+ tmp58 = tmp56 != tmp57
835
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
836
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
837
+ tmp61 = tmp59 != tmp60
838
+ tmp62 = tmp58 & tmp61
839
+ tmp63 = tmp56 + tmp54
840
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
841
+ tmp65 = tmp64.to(tl.int64)
842
+ tmp66 = tmp65 < tmp47
843
+ tmp67 = tmp55 & tmp66
844
+ tmp68 = tmp43 - tmp42
845
+ tmp69 = (tmp68 % tmp54)
846
+ tmp70 = tmp69 != tmp57
847
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
848
+ tmp72 = tmp71 != tmp60
849
+ tmp73 = tmp70 & tmp72
850
+ tmp74 = tmp69 + tmp54
851
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
852
+ tmp76 = tmp75 == tmp57
853
+ tmp77 = tmp67 & tmp76
854
+ tmp78 = tmp53 | tmp77
855
+ mask_mod_output = tmp78
856
+
857
+ # (grads) apply mask for fully masked block
858
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
859
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
860
+ if not PRESCALE_QK:
861
+ post_mod_scores *= RCP_LN2
862
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
863
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
864
+ # Compute dV.
865
+ ppT = pT
866
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
867
+ if IS_DIVISIBLE:
868
+ Di = tl.load(DELTA + offs_m1)
869
+ else:
870
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
871
+ # Compute dP and dS.
872
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
873
+ dsT = pT * (dpT - Di[None, :])
874
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
875
+ tmp79 = (dsT)
876
+ grad_scores = tmp79
877
+
878
+
879
+
880
+ if not IS_DIVISIBLE:
881
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
882
+
883
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
884
+ if not WRITE_DQ:
885
+ idx_b = off_z
886
+ idx_h = off_hq
887
+ idx_m = m
888
+ idx_n = n
889
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
890
+
891
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
892
+ dsT = grad_scores
893
+ if not IS_FULL_BLOCKS:
894
+ # (grads) apply mask for partially unmasked block
895
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
896
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
897
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
898
+
899
+ return dk, dv
900
+
901
+ # Utility triton funcs
902
+ @triton.jit
903
+ def get_offset_for_next_block(
904
+ loop_iter, col_indices, total_blocks,
905
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
906
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
907
+ ):
908
+ if BLOCKS_ARE_CONTIGUOUS:
909
+ return BLOCK
910
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
911
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
912
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
913
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
914
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
915
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
916
+ return offset
917
+
918
+ @triton.jit
919
+ def get_bounded_indices(indices, max_len=None):
920
+ return indices % max_len if max_len is not None else indices
921
+
922
+ @triton.jit
923
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
924
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
925
+ return tl.load(block_ptr)
926
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
927
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
928
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
929
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
930
+ else:
931
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
932
+
933
+ @triton.jit
934
+ def load_checked_2d(
935
+ ptr,
936
+ offs_m,
937
+ offs_n,
938
+ stride_m,
939
+ stride_n,
940
+ IS_DIVISIBLE_M: tl.constexpr,
941
+ IS_DIVISIBLE_N: tl.constexpr,
942
+ M_LEN: tl.constexpr,
943
+ N_LEN: tl.constexpr,
944
+ ):
945
+ # Calculate final pointer if strides are provided
946
+ if stride_m is not None and stride_n is not None:
947
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
948
+
949
+ # Handle all masking cases
950
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
951
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
952
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
953
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
954
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
955
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
956
+ else: # Both divisible
957
+ return tl.load(ptr)
958
+ ''', device_str='cuda')
959
+
960
+
961
+ async_compile.wait(globals())
962
+ del async_compile
963
+
964
+ class Runner:
965
+ def __init__(self, partitions):
966
+ self.partitions = partitions
967
+
968
+ def recursively_apply_fns(self, fns):
969
+ new_callables = []
970
+ for fn, c in zip(fns, self.partitions):
971
+ new_callables.append(fn(c))
972
+ self.partitions = new_callables
973
+
974
+ def call(self, args):
975
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1 = args
976
+ args.clear()
977
+ assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1))
978
+ assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1))
979
+ assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1))
980
+ assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1))
981
+ assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1))
982
+ assert_size_stride(primals_6, (2, ), (1, ))
983
+ assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1))
984
+ assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1))
985
+ assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1))
986
+ assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1))
987
+ assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1))
988
+ assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1))
989
+ assert_size_stride(getitem, (2, 32, 2048, 128), (8388608, 128, 4096, 1))
990
+ assert_size_stride(getitem_1, (2, 32, 2048), (65536, 2048, 1))
991
+ assert_size_stride(tangents_1, (2, 32, 2048, 128), (8388608, 262144, 128, 1))
992
+ with torch.cuda._DeviceGuard(4):
993
+ torch.cuda.set_device(4)
994
+ buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32)
995
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
996
+ stream4 = get_raw_stream(4)
997
+ triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 131072, 128, stream=stream4)
998
+ del getitem
999
+ buf3 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16)
1000
+ buf4 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16)
1001
+ buf5 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16)
1002
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
1003
+ stream4 = get_raw_stream(4)
1004
+ triton_tem_fused_zeros_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_9, primals_10, primals_7, primals_8, primals_11, primals_12, primals_6, buf5, 80, 2, 8, stream=stream4)
1005
+ del buf1
1006
+ del getitem_1
1007
+ del primals_1
1008
+ del primals_10
1009
+ del primals_11
1010
+ del primals_12
1011
+ del primals_2
1012
+ del primals_3
1013
+ del primals_4
1014
+ del primals_5
1015
+ del primals_6
1016
+ del primals_7
1017
+ del primals_8
1018
+ del primals_9
1019
+ del tangents_1
1020
+ return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, None, )
1021
+
1022
+ runner = Runner(partitions=[])
1023
+ call = runner.call
1024
+ recursively_apply_fns = runner.recursively_apply_fns
1025
+
1026
+
1027
+ def benchmark_compiled_module(times=10, repeat=10):
1028
+ from torch._dynamo.testing import rand_strided
1029
+ from torch._inductor.utils import print_performance
1030
+ primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16)
1031
+ primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:4', dtype=torch.bfloat16)
1032
+ primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:4', dtype=torch.bfloat16)
1033
+ primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32)
1034
+ primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32)
1035
+ primals_6 = rand_strided((2, ), (1, ), device='cuda:4', dtype=torch.int64)
1036
+ primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32)
1037
+ primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32)
1038
+ primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32)
1039
+ primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32)
1040
+ primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32)
1041
+ primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32)
1042
+ getitem = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16)
1043
+ getitem_1 = rand_strided((2, 32, 2048), (65536, 2048, 1), device='cuda:4', dtype=torch.float32)
1044
+ tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:4', dtype=torch.bfloat16)
1045
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1])
1046
+ return print_performance(fn, times=times, repeat=repeat)
1047
+
1048
+
1049
+ if __name__ == "__main__":
1050
+ from torch._inductor.wrapper_benchmark import compiled_module_main
1051
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/dw/cdwf7pztwx35f2ysnyf6io3giyljdt7efoxairyx6so6kpwdnnl2.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1
101
+
102
+ ZQ = 2
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 2
107
+ KV_LEN = ks0
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 2
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 16*ks1
149
+ stride_kv_idx_m = ks1
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks2
245
+ stride_q_idx_h = 16*ks3
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 2048
385
+ KV_LEN = ks0
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = tl.full([1], 2048, tl.int32)
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order to since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = False
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = 2048
596
+ KV_LEN = ks0
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = False
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
682
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = tl.full([1], 2048, tl.int32)
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
SpecForge-ext/cache/compiled_kernels/dw/cdwxivilyaij5fi345sh6qe7kemmtker7fznljyr22uuhwbwlgsx.py ADDED
@@ -0,0 +1,675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['6_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
17
+ import triton
18
+ import triton.language as tl
19
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
20
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
21
+
22
+ aten = torch.ops.aten
23
+ inductor_ops = torch.ops.inductor
24
+ _quantized = torch.ops._quantized
25
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
26
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
27
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
28
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
29
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
30
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
31
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
32
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
33
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
34
+ async_compile = AsyncCompile()
35
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
36
+
37
+
38
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/j2/cj2u2laawubvef7t5rtvzax6zebordlqljuy3dh5yawyzullirpa.py
39
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
40
+ # Source node to ATen node mapping:
41
+ # flex_attention => flex_attention
42
+ # Graph fragment:
43
+ # %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_1]
44
+ # %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:3" = PlaceHolder[target=primals_2]
45
+ # %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:3" = PlaceHolder[target=primals_3]
46
+ # %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=getitem_1]
47
+ # %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf1]
48
+ # %primals_5 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_5]
49
+ # %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=primals_4]
50
+ # %primals_7 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_7]
51
+ # %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=primals_8]
52
+ # %primals_6 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=primals_6]
53
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {})
54
+ # return %getitem
55
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
56
+ import triton
57
+ import triton.language as tl
58
+
59
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
60
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
61
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
62
+
63
+ @triton_heuristics.template(
64
+
65
+ num_stages=3,
66
+ num_warps=8,
67
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
68
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
69
+
70
+ )
71
+ @triton.jit
72
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0):
73
+ PRESCALE_QK : tl.constexpr = False
74
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
75
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
76
+ WRITE_DQ : tl.constexpr = True
77
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
78
+ OUTPUT_MAX : tl.constexpr = False
79
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
80
+ IS_DIVISIBLE : tl.constexpr = True
81
+ SM_SCALE : tl.constexpr = 0.08838834764831843
82
+ GQA_SHARED_HEADS : tl.constexpr = 4
83
+ HAS_FULL_BLOCKS : tl.constexpr = True
84
+ QK_HEAD_DIM : tl.constexpr = 128
85
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ V_HEAD_DIM : tl.constexpr = 128
87
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
88
+ SAFE_HEAD_DIM : tl.constexpr = True
89
+ USE_TMA : tl.constexpr = False
90
+ BLOCK_M : tl.constexpr = 128
91
+ BLOCK_N : tl.constexpr = 64
92
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
93
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
94
+ INDEX_DTYPE : tl.constexpr = tl.int32
95
+ Q = arg_Q
96
+ K = arg_K
97
+ V = arg_V
98
+ LSE = arg_LSE
99
+ MAX = arg_MAX
100
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
101
+ KV_IDX = arg_KV_IDX
102
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
103
+ FULL_KV_IDX = arg_FULL_KV_IDX
104
+
105
+ # Sub notation for this kernel:
106
+ #
107
+ # Q: Query, K: Key, V: Value
108
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
109
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
110
+ # V_HEAD_DIM: The dimension of the value embeddings
111
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
112
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
113
+ #
114
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
115
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
116
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
117
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
118
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
119
+ #
120
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
121
+ #
122
+ # (Modifiable) Performance tuning options
123
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
124
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
125
+
126
+ # The below are kernel options that can be applied for certain score_mods,
127
+ # or involve a numerics vs. perf tradeoff
128
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
129
+ # about 20% more numerical error, but slightly faster.
130
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
131
+ # is not masked out? If so, we can skip an extra safety check
132
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
133
+ # contiguous? If so, we don't need to do an indirect jump for every block
134
+
135
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
136
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
137
+
138
+ # Define strides of inputs
139
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
140
+ stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1
141
+ stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1
142
+
143
+ ZQ = 2
144
+ HQ = 32
145
+ Q_LEN = 2048
146
+ ZKV = 2
147
+ KV_LEN = 2048
148
+
149
+ MATMUL_PRECISION = Q.dtype.element_ty
150
+
151
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
152
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
153
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
154
+
155
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
156
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
157
+ off_zkv = off_zq % ZKV
158
+ off_hkv = off_hq // GQA_SHARED_HEADS
159
+ off_g = off_hq % GQA_SHARED_HEADS
160
+
161
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
162
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
163
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
164
+
165
+ Q = Q + q_offset
166
+ K = K + k_offset
167
+ V = V + v_offset
168
+
169
+ # Setting up the TMA descriptors for Q, K, V
170
+ desc_q = None
171
+ desc_k = None
172
+ desc_v = None
173
+
174
+ SPARSE_Z = 2
175
+ SPARSE_HQ = 1
176
+
177
+ sparse_idx_z = off_zq % SPARSE_Z
178
+ sparse_idx_hq = off_hq % SPARSE_HQ
179
+
180
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
181
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
182
+
183
+ stride_kv_num_blks_h = 16
184
+ stride_kv_idx_h = 256
185
+ stride_kv_idx_m = 16
186
+
187
+ # initialize pointer to m and l
188
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
189
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
190
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
191
+
192
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
193
+
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
196
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
197
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
198
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
199
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
200
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
201
+
202
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
203
+ # We don't know anything "special" about these blocks, so we need to apply
204
+ # both score_mod and mask_mod to it
205
+ kv_indices = KV_IDX + sparse_kv_idx_offset
206
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
207
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
208
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
209
+
210
+
211
+ # K and V pointers will be passed directly to forward_inner
212
+
213
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
214
+
215
+
216
+ acc, l_i, m_i = forward_inner(
217
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
218
+ q, K, V,
219
+ desc_k, desc_v, Q_LEN, KV_LEN,
220
+ acc, l_i, m_i,
221
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
222
+ kv_start,
223
+ kv_indices, kv_num_blocks,
224
+ 0, block_n_end,
225
+ MATMUL_PRECISION,
226
+ stride_kk, stride_kn, stride_vn, stride_vk,
227
+ IS_FULL_BLOCKS=False,
228
+ )
229
+
230
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
231
+ # We know these blocks are guaranteed to be "full", so we don't need to
232
+ # apply mask_mod to them - only score_mod
233
+ if HAS_FULL_BLOCKS:
234
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
235
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
236
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
237
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
238
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
239
+ # K and V pointers will be passed directly to forward_inner
240
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
241
+
242
+ acc, l_i, m_i = forward_inner(
243
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
244
+ q, K, V,
245
+ desc_k, desc_v, Q_LEN, KV_LEN,
246
+ acc, l_i, m_i,
247
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
248
+ kv_start,
249
+ kv_indices, kv_num_blocks,
250
+ 0, block_n_end,
251
+ MATMUL_PRECISION,
252
+ stride_kk, stride_kn, stride_vn, stride_vk,
253
+ IS_FULL_BLOCKS=True,
254
+ )
255
+
256
+
257
+ # [Note] Handle fully masked out rows:
258
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
259
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
260
+ l_i = tl.where(l_i == 0.0, 1, l_i)
261
+
262
+ acc = acc / l_i[:, None]
263
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
264
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
265
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
266
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
267
+
268
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
269
+
270
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
271
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
272
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
273
+
274
+ if OUTPUT_LOGSUMEXP:
275
+ off_hz = off_zq * HQ + off_hq
276
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
277
+ lse = m_i + tl.math.log2(l_i)
278
+ if IS_DIVISIBLE:
279
+ tl.store(l_ptrs, lse)
280
+ else:
281
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
282
+
283
+ if OUTPUT_MAX:
284
+ off_hz = off_zq * HQ + off_hq
285
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
286
+ if IS_DIVISIBLE:
287
+ tl.store(max_ptrs, m_i)
288
+ else:
289
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
290
+
291
+
292
+ # Utility triton funcs
293
+ @triton.jit
294
+ def get_offset_for_next_block(
295
+ loop_iter, col_indices, total_blocks,
296
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
297
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
298
+ ):
299
+ if BLOCKS_ARE_CONTIGUOUS:
300
+ return BLOCK
301
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
302
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
303
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
304
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
305
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
306
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
307
+ return offset
308
+
309
+ @triton.jit
310
+ def get_bounded_indices(indices, max_len=None):
311
+ return indices % max_len if max_len is not None else indices
312
+
313
+ @triton.jit
314
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
315
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr)
317
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
319
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
320
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
321
+ else:
322
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
323
+
324
+ @triton.jit
325
+ def load_checked_2d(
326
+ ptr,
327
+ offs_m,
328
+ offs_n,
329
+ stride_m,
330
+ stride_n,
331
+ IS_DIVISIBLE_M: tl.constexpr,
332
+ IS_DIVISIBLE_N: tl.constexpr,
333
+ M_LEN: tl.constexpr,
334
+ N_LEN: tl.constexpr,
335
+ ):
336
+ # Calculate final pointer if strides are provided
337
+ if stride_m is not None and stride_n is not None:
338
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
339
+
340
+ # Handle all masking cases
341
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
343
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
345
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
346
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
347
+ else: # Both divisible
348
+ return tl.load(ptr)
349
+
350
+
351
+ # Common Imports
352
+ @triton.jit
353
+ def forward_block_mn(
354
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
355
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
356
+ # accumulated values
357
+ acc, l_i, m_i,
358
+ # Offsets
359
+ off_z, off_h, offs_m, offs_n,
360
+ # Offsets needed for TMA loads
361
+ kv_start,
362
+ kv_offset,
363
+ MATMUL_PRECISION, RCP_LN2,
364
+ # Strides for K and V
365
+ stride_kk, stride_kn, stride_vn, stride_vk,
366
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
367
+
368
+ ):
369
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
370
+ PRESCALE_QK : tl.constexpr = False
371
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
372
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
373
+ WRITE_DQ : tl.constexpr = True
374
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
375
+ OUTPUT_MAX : tl.constexpr = False
376
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
377
+ IS_DIVISIBLE : tl.constexpr = True
378
+ SM_SCALE : tl.constexpr = 0.08838834764831843
379
+ GQA_SHARED_HEADS : tl.constexpr = 4
380
+ HAS_FULL_BLOCKS : tl.constexpr = True
381
+ QK_HEAD_DIM : tl.constexpr = 128
382
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ V_HEAD_DIM : tl.constexpr = 128
384
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
385
+ SAFE_HEAD_DIM : tl.constexpr = True
386
+ USE_TMA : tl.constexpr = False
387
+ BLOCK_M : tl.constexpr = 128
388
+ BLOCK_N : tl.constexpr = 64
389
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
390
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
391
+ INDEX_DTYPE : tl.constexpr = tl.int32
392
+
393
+
394
+ # -- load k --
395
+ # NB reversed order to since K is transposed
396
+ kv_base_offset = kv_start + kv_offset
397
+
398
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
399
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
400
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
401
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
402
+
403
+ k = tl.trans(k)
404
+ # -- compute qk ---
405
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
406
+ if not PRESCALE_QK:
407
+ qk *= SM_SCALE
408
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
409
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
410
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
411
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
412
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
413
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
414
+
415
+ tmp0 = (qk)
416
+ post_mod_scores = tmp0
417
+
418
+
419
+ if CHECK_BLOCK_BOUNDARY:
420
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
421
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
422
+
423
+ if not IS_FULL_BLOCKS:
424
+ tmp1 = tl.full([1], False, tl.int1)
425
+ tmp2 = (m)
426
+ tmp3 = (n)
427
+ tmp4 = tmp2 >= tmp3
428
+ tmp5 = tmp3.to(tl.int64)
429
+ tmp6 = (off_z)
430
+ tmp7 = tl.load(in_ptr9 + tmp6)
431
+ tmp8 = tmp5 < tmp7
432
+ tmp9 = tmp2.to(tl.int64)
433
+ tmp10 = tmp9 < tmp7
434
+ tmp11 = tmp8 & tmp10
435
+ tmp12 = tmp4 & tmp11
436
+ tmp13 = tmp1 | tmp12
437
+ tmp14 = tl.full([1], 2048, tl.int32)
438
+ tmp15 = tmp3 >= tmp14
439
+ tmp16 = (tmp3 % tmp14)
440
+ tmp17 = tl.full([1], 0, tl.int32)
441
+ tmp18 = tmp16 != tmp17
442
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
443
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
444
+ tmp21 = tmp19 != tmp20
445
+ tmp22 = tmp18 & tmp21
446
+ tmp23 = tmp16 + tmp14
447
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
448
+ tmp25 = tmp24.to(tl.int64)
449
+ tmp26 = tmp25 < tmp7
450
+ tmp27 = tmp15 & tmp26
451
+ tmp28 = tmp3 - tmp2
452
+ tmp29 = (tmp28 % tmp14)
453
+ tmp30 = tmp29 != tmp17
454
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
455
+ tmp32 = tmp31 != tmp20
456
+ tmp33 = tmp30 & tmp32
457
+ tmp34 = tmp29 + tmp14
458
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
459
+ tmp36 = tmp35 == tmp17
460
+ tmp37 = tmp27 & tmp36
461
+ tmp38 = tmp13 | tmp37
462
+ mask_mod_output = tmp38
463
+
464
+
465
+ if CHECK_BLOCK_BOUNDARY:
466
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
467
+ # apply mask for partially unmasked blocks
468
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
469
+
470
+ if not PRESCALE_QK:
471
+ post_mod_scores *= RCP_LN2
472
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
473
+
474
+ # -- compute scaling constant ---
475
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
476
+ if not ROWS_GUARANTEED_SAFE:
477
+ masked_out_rows = (m_ij == float("-inf"))
478
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
479
+ else:
480
+ m_ij_masked = m_ij
481
+
482
+ alpha = tl.math.exp2(m_i - m_ij_masked)
483
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
484
+
485
+ # NB: l_i update is pulled up here since it's a bit faster
486
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
487
+ # m_ij
488
+ l_i = l_i * alpha + tl.sum(p, 1)
489
+ # # -- scale and update acc --
490
+ acc = acc * alpha[:, None]
491
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
492
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
493
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
494
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
495
+
496
+ # -- update m_i
497
+ m_i = m_ij
498
+
499
+ return acc, l_i, m_i
500
+
501
+ @triton.jit
502
+ def forward_inner(
503
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
504
+ q, K, V,
505
+ desc_k, desc_v, Q_LEN, KV_LEN,
506
+ # accumulated values
507
+ acc, l_i, m_i,
508
+ # Offsets used as inputs to score_mod & mask_mod
509
+ # of size [BLOCK_M, BLOCK_N] or scalar.
510
+ off_z, off_h, offs_m, offs_n,
511
+ # Offsets needed for TMA loads
512
+ kv_start,
513
+ # blocksparse data
514
+ kv_indices, kv_num_blocks,
515
+ # start kv and end kv block
516
+ block_n_start, block_n_end,
517
+ MATMUL_PRECISION,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS,
521
+ ):
522
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
523
+ PRESCALE_QK : tl.constexpr = False
524
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
525
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
526
+ WRITE_DQ : tl.constexpr = True
527
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
528
+ OUTPUT_MAX : tl.constexpr = False
529
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
530
+ IS_DIVISIBLE : tl.constexpr = True
531
+ SM_SCALE : tl.constexpr = 0.08838834764831843
532
+ GQA_SHARED_HEADS : tl.constexpr = 4
533
+ HAS_FULL_BLOCKS : tl.constexpr = True
534
+ QK_HEAD_DIM : tl.constexpr = 128
535
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
536
+ V_HEAD_DIM : tl.constexpr = 128
537
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
538
+ SAFE_HEAD_DIM : tl.constexpr = True
539
+ USE_TMA : tl.constexpr = False
540
+ BLOCK_M : tl.constexpr = 128
541
+ BLOCK_N : tl.constexpr = 64
542
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
543
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
544
+ INDEX_DTYPE : tl.constexpr = tl.int32
545
+
546
+
547
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
548
+ RCP_LN2: tl.constexpr = 1.44269504
549
+
550
+ if PRESCALE_QK:
551
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
552
+
553
+ kv_offset = 0
554
+
555
+ # loop over k, v and update accumulator until block_n_end
556
+ for start_n in range(block_n_start, block_n_end):
557
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
558
+ if IS_DIVISIBLE:
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS,
573
+ )
574
+ else:
575
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
576
+ # it's on par or slightly faster than only applying to the last block in fwd.
577
+ # However, we choose different strategy for bwd, where we only apply mod & mask
578
+ # to the last block because it's faster a lot.
579
+ acc, l_i, m_i = forward_block_mn(
580
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
581
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
582
+ # accumulated values
583
+ acc, l_i, m_i,
584
+ # Offsets
585
+ off_z, off_h, offs_m, offs_n,
586
+ # Offsets needed for TMA loads
587
+ kv_start,
588
+ kv_offset,
589
+ MATMUL_PRECISION, RCP_LN2,
590
+ # Strides for K and V
591
+ stride_kk, stride_kn, stride_vn, stride_vk,
592
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
593
+ )
594
+
595
+
596
+
597
+ offset = get_offset_for_next_block(
598
+ start_n, kv_indices, kv_num_blocks,
599
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
600
+ )
601
+
602
+ offs_n = offs_n + offset
603
+ kv_offset += offset
604
+
605
+
606
+ return acc, l_i, m_i
607
+ ''', device_str='cuda')
608
+
609
+
610
+ async_compile.wait(globals())
611
+ del async_compile
612
+
613
+ class Runner:
614
+ def __init__(self, partitions):
615
+ self.partitions = partitions
616
+
617
+ def recursively_apply_fns(self, fns):
618
+ new_callables = []
619
+ for fn, c in zip(fns, self.partitions):
620
+ new_callables.append(fn(c))
621
+ self.partitions = new_callables
622
+
623
+ def call(self, args):
624
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12 = args
625
+ args.clear()
626
+ assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1))
627
+ assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1))
628
+ assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1))
629
+ assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1))
630
+ assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1))
631
+ assert_size_stride(primals_6, (2, ), (1, ))
632
+ assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1))
633
+ assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1))
634
+ assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1))
635
+ assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1))
636
+ assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1))
637
+ assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1))
638
+ with torch.cuda._DeviceGuard(3):
639
+ torch.cuda.set_device(3)
640
+ buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32)
641
+ buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32)
642
+ buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16)
643
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
644
+ stream3 = get_raw_stream(3)
645
+ triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 2, 32, stream=stream3)
646
+ del buf1
647
+ return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, )
648
+
649
+ runner = Runner(partitions=[])
650
+ call = runner.call
651
+ recursively_apply_fns = runner.recursively_apply_fns
652
+
653
+
654
+ def benchmark_compiled_module(times=10, repeat=10):
655
+ from torch._dynamo.testing import rand_strided
656
+ from torch._inductor.utils import print_performance
657
+ primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16)
658
+ primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16)
659
+ primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16)
660
+ primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32)
661
+ primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32)
662
+ primals_6 = rand_strided((2, ), (1, ), device='cuda:3', dtype=torch.int64)
663
+ primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32)
664
+ primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32)
665
+ primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32)
666
+ primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32)
667
+ primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32)
668
+ primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32)
669
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12])
670
+ return print_performance(fn, times=times, repeat=repeat)
671
+
672
+
673
+ if __name__ == "__main__":
674
+ from torch._inductor.wrapper_benchmark import compiled_module_main
675
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/e6/ce6g3e5xikzaf3a5wmxill5os7magq3p3hzz7uw37za4jjui6tk6.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1
88
+
89
+ ZQ = 8
90
+ HQ = 32
91
+ Q_LEN = 2048
92
+ ZKV = 8
93
+ KV_LEN = ks0
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 8
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = 16
130
+ stride_kv_idx_h = 16*ks1
131
+ stride_kv_idx_m = ks1
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to it
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831843
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order to since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = tl.full([1], False, tl.int1)
371
+ tmp2 = (m)
372
+ tmp3 = (n)
373
+ tmp4 = tmp2 >= tmp3
374
+ tmp5 = tmp3.to(tl.int64)
375
+ tmp6 = (off_z)
376
+ tmp7 = tl.load(in_ptr9 + tmp6)
377
+ tmp8 = tmp5 < tmp7
378
+ tmp9 = tmp2.to(tl.int64)
379
+ tmp10 = tmp9 < tmp7
380
+ tmp11 = tmp8 & tmp10
381
+ tmp12 = tmp4 & tmp11
382
+ tmp13 = tmp1 | tmp12
383
+ tmp14 = tl.full([1], 2048, tl.int32)
384
+ tmp15 = tmp3 >= tmp14
385
+ tmp16 = (tmp3 % tmp14)
386
+ tmp17 = tl.full([1], 0, tl.int32)
387
+ tmp18 = tmp16 != tmp17
388
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
389
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
390
+ tmp21 = tmp19 != tmp20
391
+ tmp22 = tmp18 & tmp21
392
+ tmp23 = tmp16 + tmp14
393
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
394
+ tmp25 = tmp24.to(tl.int64)
395
+ tmp26 = tmp25 < tmp7
396
+ tmp27 = tmp15 & tmp26
397
+ tmp28 = tmp3 - tmp2
398
+ tmp29 = (tmp28 % tmp14)
399
+ tmp30 = tmp29 != tmp17
400
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
401
+ tmp32 = tmp31 != tmp20
402
+ tmp33 = tmp30 & tmp32
403
+ tmp34 = tmp29 + tmp14
404
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
405
+ tmp36 = tmp35 == tmp17
406
+ tmp37 = tmp27 & tmp36
407
+ tmp38 = tmp13 | tmp37
408
+ mask_mod_output = tmp38
409
+
410
+
411
+ if CHECK_BLOCK_BOUNDARY:
412
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
413
+ # apply mask for partially unmasked blocks
414
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
415
+
416
+ if not PRESCALE_QK:
417
+ post_mod_scores *= RCP_LN2
418
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419
+
420
+ # -- compute scaling constant ---
421
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
422
+ if not ROWS_GUARANTEED_SAFE:
423
+ masked_out_rows = (m_ij == float("-inf"))
424
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
425
+ else:
426
+ m_ij_masked = m_ij
427
+
428
+ alpha = tl.math.exp2(m_i - m_ij_masked)
429
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
430
+
431
+ # NB: l_i update is pulled up here since it's a bit faster
432
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
433
+ # m_ij
434
+ l_i = l_i * alpha + tl.sum(p, 1)
435
+ # # -- scale and update acc --
436
+ acc = acc * alpha[:, None]
437
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
438
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
439
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
440
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
441
+
442
+ # -- update m_i
443
+ m_i = m_ij
444
+
445
+ return acc, l_i, m_i
446
+
447
+ @triton.jit
448
+ def forward_inner(
449
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
450
+ q, K, V,
451
+ desc_k, desc_v, Q_LEN, KV_LEN,
452
+ # accumulated values
453
+ acc, l_i, m_i,
454
+ # Offsets used as inputs to score_mod & mask_mod
455
+ # of size [BLOCK_M, BLOCK_N] or scalar.
456
+ off_z, off_h, offs_m, offs_n,
457
+ # Offsets needed for TMA loads
458
+ kv_start,
459
+ # blocksparse data
460
+ kv_indices, kv_num_blocks,
461
+ # start kv and end kv block
462
+ block_n_start, block_n_end,
463
+ MATMUL_PRECISION,
464
+ # Strides for K and V
465
+ stride_kk, stride_kn, stride_vn, stride_vk,
466
+ IS_FULL_BLOCKS,
467
+ ):
468
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
469
+ PRESCALE_QK : tl.constexpr = False
470
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
471
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
472
+ WRITE_DQ : tl.constexpr = True
473
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
474
+ OUTPUT_MAX : tl.constexpr = False
475
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
476
+ IS_DIVISIBLE : tl.constexpr = False
477
+ SM_SCALE : tl.constexpr = 0.08838834764831843
478
+ GQA_SHARED_HEADS : tl.constexpr = 4
479
+ HAS_FULL_BLOCKS : tl.constexpr = True
480
+ QK_HEAD_DIM : tl.constexpr = 128
481
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
482
+ V_HEAD_DIM : tl.constexpr = 128
483
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
484
+ SAFE_HEAD_DIM : tl.constexpr = True
485
+ USE_TMA : tl.constexpr = False
486
+ BLOCK_M : tl.constexpr = 128
487
+ BLOCK_N : tl.constexpr = 64
488
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
489
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
490
+ INDEX_DTYPE : tl.constexpr = tl.int32
491
+
492
+
493
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
494
+ RCP_LN2: tl.constexpr = 1.44269504
495
+
496
+ if PRESCALE_QK:
497
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
498
+
499
+ kv_offset = 0
500
+
501
+ # loop over k, v and update accumulator until block_n_end
502
+ for start_n in range(block_n_start, block_n_end):
503
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
504
+ if IS_DIVISIBLE:
505
+ acc, l_i, m_i = forward_block_mn(
506
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
507
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
508
+ # accumulated values
509
+ acc, l_i, m_i,
510
+ # Offsets
511
+ off_z, off_h, offs_m, offs_n,
512
+ # Offsets needed for TMA loads
513
+ kv_start,
514
+ kv_offset,
515
+ MATMUL_PRECISION, RCP_LN2,
516
+ # Strides for K and V
517
+ stride_kk, stride_kn, stride_vn, stride_vk,
518
+ IS_FULL_BLOCKS,
519
+ )
520
+ else:
521
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
522
+ # it's on par or slightly faster than only applying to the last block in fwd.
523
+ # However, we choose different strategy for bwd, where we only apply mod & mask
524
+ # to the last block because it's faster a lot.
525
+ acc, l_i, m_i = forward_block_mn(
526
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
527
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
528
+ # accumulated values
529
+ acc, l_i, m_i,
530
+ # Offsets
531
+ off_z, off_h, offs_m, offs_n,
532
+ # Offsets needed for TMA loads
533
+ kv_start,
534
+ kv_offset,
535
+ MATMUL_PRECISION, RCP_LN2,
536
+ # Strides for K and V
537
+ stride_kk, stride_kn, stride_vn, stride_vk,
538
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
539
+ )
540
+
541
+
542
+
543
+ offset = get_offset_for_next_block(
544
+ start_n, kv_indices, kv_num_blocks,
545
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
546
+ )
547
+
548
+ offs_n = offs_n + offset
549
+ kv_offset += offset
550
+
551
+
552
+ return acc, l_i, m_i
SpecForge-ext/cache/compiled_kernels/e6/ce6sgne5yx3pyeim455xwwbqvpu2da3rro3rzyopm3res7mhkspf.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = True
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1
101
+
102
+ ZQ = 2
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 2
107
+ KV_LEN = 2048
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 2
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 256
149
+ stride_kv_idx_m = 16
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 16
245
+ stride_q_idx_h = 256
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
# Inner loop of the dQ computation: walks the (block-sparse) list of KV blocks
# for one Q tile, accumulating into `dq` via bwd_dq_block_mn and advancing the
# K/V pointers between iterations using the sparse KV index table.
@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
    K, V,  # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Compile-time options baked in by Inductor (duplicated per-function by the
    # template codegen).
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 2048
    KV_LEN = 2048

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Pointers laid out transposed (head-dim major) so blocks load as kT / vT.
    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    # Cap the iteration count at the number of KV tiles that actually exist.
    hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))

    for start_n in range(0, hi):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
            dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Increment pointers.
        offset = get_offset_for_next_block(
            start_n, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        kT_ptrs += offset * stride_kn
        vT_ptrs += offset * stride_vn

        offs_n2 += offset

    return dq
420
+
421
+
422
# Processes one (BLOCK_M2 x BLOCK_N2) tile of the dQ computation: recomputes
# the attention probabilities from q/kT and the saved LSE, applies the
# Inductor-generated mask_mod (the tmp* chain), forms dS = p * (dP - Di), and
# accumulates dS @ K into `dq`.
@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Compile-time options baked in by Inductor (duplicated per-function by the
    # template codegen).
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # NB reversed order to since K is transposed
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
    # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)

    # Identity score_mod: the captured score modification is a no-op here.
    tmp0 = (qk)
    post_mod_scores = tmp0

    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Inductor-generated mask_mod evaluated elementwise on (m, n).
        # It reads a per-batch int64 from in_ptr16 (presumably the true
        # sequence length for batch off_z — TODO confirm against the
        # mask_mod source) and combines a causal-style m >= n test with a
        # modular condition on n and (n - m) against the constant 2048.
        tmp1 = tl.full([1], False, tl.int1)
        tmp2 = (m)
        tmp3 = (n)
        tmp4 = tmp2 >= tmp3
        tmp5 = tmp3.to(tl.int64)
        tmp6 = (off_z)
        tmp7 = tl.load(in_ptr16 + tmp6)
        tmp8 = tmp5 < tmp7
        tmp9 = tmp2.to(tl.int64)
        tmp10 = tmp9 < tmp7
        tmp11 = tmp8 & tmp10
        tmp12 = tmp4 & tmp11
        tmp13 = tmp1 | tmp12
        tmp14 = tl.full([1], 2048, tl.int32)
        tmp15 = tmp3 >= tmp14
        # Python-style (floored) modulo emulation: adjust the remainder when
        # operand signs differ, so results match torch semantics.
        tmp16 = (tmp3 % tmp14)
        tmp17 = tl.full([1], 0, tl.int32)
        tmp18 = tmp16 != tmp17
        tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
        tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
        tmp21 = tmp19 != tmp20
        tmp22 = tmp18 & tmp21
        tmp23 = tmp16 + tmp14
        tmp24 = tl.where(tmp22, tmp23, tmp16)
        tmp25 = tmp24.to(tl.int64)
        tmp26 = tmp25 < tmp7
        tmp27 = tmp15 & tmp26
        tmp28 = tmp3 - tmp2
        tmp29 = (tmp28 % tmp14)
        tmp30 = tmp29 != tmp17
        tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
        tmp32 = tmp31 != tmp20
        tmp33 = tmp30 & tmp32
        tmp34 = tmp29 + tmp14
        tmp35 = tl.where(tmp33, tmp34, tmp29)
        tmp36 = tmp35 == tmp17
        tmp37 = tmp27 & tmp36
        tmp38 = tmp13 | tmp37
        mask_mod_output = tmp38

        # apply mask for partial masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # Recompute attention probabilities from the saved logsumexp (base-2).
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB reversed order to since V is transposed
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # Identity joint score_mod gradient: pass ds through unchanged.
    tmp39 = (ds)
    grad_scores = tmp39

    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        # scatter_mask is computed for captured-buffer gradient scatters;
        # unused here because this score_mod writes no extra buffers.
        scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq
556
+
557
+
558
# Inner loop of the dK/dV computation: walks the (block-sparse) list of Q
# blocks relevant to one KV tile, accumulating into `dk`/`dv` via
# bwd_dkdv_block_mn and advancing the Q/DO pointers between iterations using
# the sparse Q index table.
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
    Q, DO, DELTA, LSE,  # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Compile-time options baked in by Inductor (duplicated per-function by the
    # template codegen).
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 2048
    KV_LEN = 2048

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Q is addressed transposed (head-dim major) so blocks load as qT.
    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    for start_m in range(0, hi):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        qT_ptrs += offset * stride_qm
        do_ptrs += offset * stride_dom
        offs_m1 += offset

    return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = True
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
682
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = tl.full([1], 2048, tl.int32)
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
SpecForge-ext/cache/compiled_kernels/f6/cf6ayxqoma6zlumium5vkfjxneuep3h7lxmtssd73sg7bynrgpyn.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = True
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1
88
+
89
+ ZQ = 2
90
+ HQ = 32
91
+ Q_LEN = 2048
92
+ ZKV = 2
93
+ KV_LEN = 2048
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 2
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = 16
130
+ stride_kv_idx_h = 256
131
+ stride_kv_idx_m = 16
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to it
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = True
324
+ SM_SCALE : tl.constexpr = 0.08838834764831843
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order to since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid access memory out of bound,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = tl.full([1], False, tl.int1)
371
+ tmp2 = (m)
372
+ tmp3 = (n)
373
+ tmp4 = tmp2 >= tmp3
374
+ tmp5 = tmp3.to(tl.int64)
375
+ tmp6 = (off_z)
376
+ tmp7 = tl.load(in_ptr9 + tmp6)
377
+ tmp8 = tmp5 < tmp7
378
+ tmp9 = tmp2.to(tl.int64)
379
+ tmp10 = tmp9 < tmp7
380
+ tmp11 = tmp8 & tmp10
381
+ tmp12 = tmp4 & tmp11
382
+ tmp13 = tmp1 | tmp12
383
+ tmp14 = tl.full([1], 2048, tl.int32)
384
+ tmp15 = tmp3 >= tmp14
385
+ tmp16 = (tmp3 % tmp14)
386
+ tmp17 = tl.full([1], 0, tl.int32)
387
+ tmp18 = tmp16 != tmp17
388
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
389
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
390
+ tmp21 = tmp19 != tmp20
391
+ tmp22 = tmp18 & tmp21
392
+ tmp23 = tmp16 + tmp14
393
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
394
+ tmp25 = tmp24.to(tl.int64)
395
+ tmp26 = tmp25 < tmp7
396
+ tmp27 = tmp15 & tmp26
397
+ tmp28 = tmp3 - tmp2
398
+ tmp29 = (tmp28 % tmp14)
399
+ tmp30 = tmp29 != tmp17
400
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
401
+ tmp32 = tmp31 != tmp20
402
+ tmp33 = tmp30 & tmp32
403
+ tmp34 = tmp29 + tmp14
404
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
405
+ tmp36 = tmp35 == tmp17
406
+ tmp37 = tmp27 & tmp36
407
+ tmp38 = tmp13 | tmp37
408
+ mask_mod_output = tmp38
409
+
410
+
411
+ if CHECK_BLOCK_BOUNDARY:
412
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
413
+ # apply mask for partially unmasked blocks
414
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
415
+
416
+ if not PRESCALE_QK:
417
+ post_mod_scores *= RCP_LN2
418
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419
+
420
+ # -- compute scaling constant ---
421
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
422
+ if not ROWS_GUARANTEED_SAFE:
423
+ masked_out_rows = (m_ij == float("-inf"))
424
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
425
+ else:
426
+ m_ij_masked = m_ij
427
+
428
+ alpha = tl.math.exp2(m_i - m_ij_masked)
429
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
430
+
431
+ # NB: l_i update is pulled up here since it's a bit faster
432
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
433
+ # m_ij
434
+ l_i = l_i * alpha + tl.sum(p, 1)
435
+ # # -- scale and update acc --
436
+ acc = acc * alpha[:, None]
437
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
438
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
439
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
440
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
441
+
442
+ # -- update m_i
443
+ m_i = m_ij
444
+
445
+ return acc, l_i, m_i
446
+
447
+ @triton.jit
448
+ def forward_inner(
449
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
450
+ q, K, V,
451
+ desc_k, desc_v, Q_LEN, KV_LEN,
452
+ # accumulated values
453
+ acc, l_i, m_i,
454
+ # Offsets used as inputs to score_mod & mask_mod
455
+ # of size [BLOCK_M, BLOCK_N] or scalar.
456
+ off_z, off_h, offs_m, offs_n,
457
+ # Offsets needed for TMA loads
458
+ kv_start,
459
+ # blocksparse data
460
+ kv_indices, kv_num_blocks,
461
+ # start kv and end kv block
462
+ block_n_start, block_n_end,
463
+ MATMUL_PRECISION,
464
+ # Strides for K and V
465
+ stride_kk, stride_kn, stride_vn, stride_vk,
466
+ IS_FULL_BLOCKS,
467
+ ):
468
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
469
+ PRESCALE_QK : tl.constexpr = False
470
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
471
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
472
+ WRITE_DQ : tl.constexpr = True
473
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
474
+ OUTPUT_MAX : tl.constexpr = False
475
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
476
+ IS_DIVISIBLE : tl.constexpr = True
477
+ SM_SCALE : tl.constexpr = 0.08838834764831843
478
+ GQA_SHARED_HEADS : tl.constexpr = 4
479
+ HAS_FULL_BLOCKS : tl.constexpr = True
480
+ QK_HEAD_DIM : tl.constexpr = 128
481
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
482
+ V_HEAD_DIM : tl.constexpr = 128
483
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
484
+ SAFE_HEAD_DIM : tl.constexpr = True
485
+ USE_TMA : tl.constexpr = False
486
+ BLOCK_M : tl.constexpr = 128
487
+ BLOCK_N : tl.constexpr = 64
488
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
489
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
490
+ INDEX_DTYPE : tl.constexpr = tl.int32
491
+
492
+
493
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
494
+ RCP_LN2: tl.constexpr = 1.44269504
495
+
496
+ if PRESCALE_QK:
497
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
498
+
499
+ kv_offset = 0
500
+
501
+ # loop over k, v and update accumulator until block_n_end
502
+ for start_n in range(block_n_start, block_n_end):
503
+ # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
504
+ if IS_DIVISIBLE:
505
+ acc, l_i, m_i = forward_block_mn(
506
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
507
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
508
+ # accumulated values
509
+ acc, l_i, m_i,
510
+ # Offsets
511
+ off_z, off_h, offs_m, offs_n,
512
+ # Offsets needed for TMA loads
513
+ kv_start,
514
+ kv_offset,
515
+ MATMUL_PRECISION, RCP_LN2,
516
+ # Strides for K and V
517
+ stride_kk, stride_kn, stride_vn, stride_vk,
518
+ IS_FULL_BLOCKS,
519
+ )
520
+ else:
521
+ # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
522
+ # it's on par or slightly faster than only applying to the last block in fwd.
523
+ # However, we choose different strategy for bwd, where we only apply mod & mask
524
+ # to the last block because it's faster a lot.
525
+ acc, l_i, m_i = forward_block_mn(
526
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
527
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
528
+ # accumulated values
529
+ acc, l_i, m_i,
530
+ # Offsets
531
+ off_z, off_h, offs_m, offs_n,
532
+ # Offsets needed for TMA loads
533
+ kv_start,
534
+ kv_offset,
535
+ MATMUL_PRECISION, RCP_LN2,
536
+ # Strides for K and V
537
+ stride_kk, stride_kn, stride_vn, stride_vk,
538
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
539
+ )
540
+
541
+
542
+
543
+ offset = get_offset_for_next_block(
544
+ start_n, kv_indices, kv_num_blocks,
545
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
546
+ )
547
+
548
+ offs_n = offs_n + offset
549
+ kv_offset += offset
550
+
551
+
552
+ return acc, l_i, m_i
SpecForge-ext/cache/compiled_kernels/fh/cfhmsnuqfbjggcp2r4forretj7wzvobbq6w5hy337y6tmciawqkk.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 32, 'r0_': 32},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 32
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x0 = xindex
28
+ _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
29
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
30
+ r0_index = r0_offset + r0_base
31
+ r0_mask = r0_index < r0_numel
32
+ roffset = r0_offset
33
+ rindex = r0_index
34
+ r0_1 = r0_index
35
+ tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
36
+ tmp1 = tmp0.to(tl.int64)
37
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
38
+ tmp4 = _tmp3 + tmp2
39
+ _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3)
40
+ tmp3 = tl.sum(_tmp3, 1)[:, None]
41
+ tmp5 = tmp3.to(tl.int32)
42
+ tl.store(out_ptr1 + (x0), tmp5, xmask)
43
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
44
+ r0_index = r0_offset + r0_base
45
+ r0_mask = r0_index < r0_numel
46
+ roffset = r0_offset
47
+ rindex = r0_index
48
+ r0_1 = r0_index
49
+ tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
50
+ tmp7 = tmp6.to(tl.int32)
51
+ tmp8 = r0_1
52
+ tmp9 = tmp8 < tmp5
53
+ tmp10 = ks0
54
+ tmp11 = tl.where(tmp9, tmp7, tmp10)
55
+ tmp12 = 1 + ks0
56
+ tmp13 = tmp11 + tmp12
57
+ tmp14 = tmp11 < 0
58
+ tmp15 = tl.where(tmp14, tmp13, tmp11)
59
+ tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))")
60
+ tmp17 = tl.full([1, 1], 1, tl.int32)
61
+ tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask)
62
+ tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask)
SpecForge-ext/cache/compiled_kernels/fh/ebe6017c015020b128565a146c63c01eb1d20ffe6e82484e1c26bb63be24756a.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ"}
SpecForge-ext/cache/compiled_kernels/fl/cfl7aqky4mcwhud5rcyx5e6sredhx2vbbrykfa5v67vwkgveygd5.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['1_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mm/cmmp5cb4b4xchyyotwouonjdn4i7oimojhwosocnjqx2t5kcq5jf.py
38
+ # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
39
+ # Source node to ATen node mapping:
40
+ # target_head => convert_element_type
41
+ # target_p => div
42
+ # Graph fragment:
43
+ # %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:5" = PlaceHolder[target=arg0_1]
44
+ # %getitem : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:5" = PlaceHolder[target=getitem]
45
+ # %getitem_1 : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:5" = PlaceHolder[target=getitem_1]
46
+ # %convert_element_type : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {})
47
+ # %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {})
48
+ # %sub_tensor : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {})
49
+ # %exp_default : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {})
50
+ # %div : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {})
51
+ # return %getitem,%getitem_1,%div
52
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', '''
53
+ import triton
54
+ import triton.language as tl
55
+
56
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
57
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
58
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
59
+ triton_helpers.set_driver_to_gpu()
60
+
61
+ @triton_heuristics.reduction(
62
+ size_hints={'x': 16384, 'r0_': 32768},
63
+ reduction_hint=ReductionHint.INNER,
64
+ filename=__file__,
65
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}}
67
+ )
68
+ @triton.jit
69
+ def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
70
+ xnumel = 16384
71
+ r0_numel = 32000
72
+ rnumel = r0_numel
73
+ RBLOCK: tl.constexpr = R0_BLOCK
74
+ xoffset = tl.program_id(0) * XBLOCK
75
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
76
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
77
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
78
+ rbase = r0_base
79
+ x0 = xindex
80
+ _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
81
+ _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
82
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
83
+ r0_index = r0_offset + r0_base
84
+ r0_mask = r0_index < r0_numel
85
+ roffset = r0_offset
86
+ rindex = r0_index
87
+ r0_1 = r0_index
88
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
89
+ tmp1 = tmp0.to(tl.float32)
90
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
91
+
92
+ _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
93
+ _tmp3_max, _tmp3_sum, tmp2, False
94
+ )
95
+
96
+ _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max)
97
+ _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum)
98
+
99
+ tmp3, tmp4 = triton_helpers.online_softmax_reduce(
100
+ _tmp3_max, _tmp3_sum, 1, False)
101
+ tmp3 = tmp3[:, None]
102
+ tmp4 = tmp4[:, None]
103
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
104
+ r0_index = r0_offset + r0_base
105
+ r0_mask = r0_index < r0_numel
106
+ roffset = r0_offset
107
+ rindex = r0_index
108
+ r0_1 = r0_index
109
+ tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
110
+ tmp6 = tmp5.to(tl.float32)
111
+ tmp7 = tmp6 - tmp3
112
+ tmp8 = libdevice.exp(tmp7)
113
+ tmp9 = (tmp8 / tmp4)
114
+ tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask)
115
+ ''', device_str='cuda')
116
+
117
+
118
+ async_compile.wait(globals())
119
+ del async_compile
120
+
121
+ class Runner:
122
+ def __init__(self, partitions):
123
+ self.partitions = partitions
124
+
125
+ def recursively_apply_fns(self, fns):
126
+ new_callables = []
127
+ for fn, c in zip(fns, self.partitions):
128
+ new_callables.append(fn(c))
129
+ self.partitions = new_callables
130
+
131
+ def call(self, args):
132
+ arg0_1, = args
133
+ args.clear()
134
+ assert_size_stride(arg0_1, (8, 2048, 32000), (65536000, 32000, 1))
135
+ with torch.cuda._DeviceGuard(5):
136
+ torch.cuda.set_device(5)
137
+ buf2 = empty_strided_cuda((8, 2048, 32000), (65536000, 32000, 1), torch.float32)
138
+ # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
139
+ stream5 = get_raw_stream(5)
140
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg0_1, buf2, 16384, 32000, stream=stream5)
141
+ del arg0_1
142
+ return (buf2, )
143
+
144
+ runner = Runner(partitions=[])
145
+ call = runner.call
146
+ recursively_apply_fns = runner.recursively_apply_fns
147
+
148
+
149
+ def benchmark_compiled_module(times=10, repeat=10):
150
+ from torch._dynamo.testing import rand_strided
151
+ from torch._inductor.utils import print_performance
152
+ arg0_1 = rand_strided((8, 2048, 32000), (65536000, 32000, 1), device='cuda:5', dtype=torch.bfloat16)
153
+ fn = lambda: call([arg0_1])
154
+ return print_performance(fn, times=times, repeat=repeat)
155
+
156
+
157
+ if __name__ == "__main__":
158
+ from torch._inductor.wrapper_benchmark import compiled_module_main
159
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/gn/cgnmjxikvi5ulcyj3uozif3le5hd26kw2kjhkcbhupqgudqi3bwn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 4096, 'r0_': 4096},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ rnumel = r0_numel
20
+ RBLOCK: tl.constexpr = R0_BLOCK
21
+ xoffset = tl.program_id(0) * XBLOCK
22
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
23
+ xmask = xindex < xnumel
24
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
25
+ rbase = r0_base
26
+ x0 = xindex
27
+ _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
28
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
29
+ r0_index = r0_offset + r0_base
30
+ r0_mask = r0_index < r0_numel
31
+ roffset = r0_offset
32
+ rindex = r0_index
33
+ r0_1 = r0_index
34
+ tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
35
+ tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
36
+ tmp4 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
37
+ tmp2 = tmp0 * tmp1
38
+ tmp3 = tmp2.to(tl.float32)
39
+ tmp5 = tmp4.to(tl.float32)
40
+ tmp6 = tmp3 * tmp5
41
+ tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
42
+ tmp9 = _tmp8 + tmp7
43
+ _tmp8 = tl.where(r0_mask & xmask, tmp9, _tmp8)
44
+ tmp8 = tl.sum(_tmp8, 1)[:, None]
45
+ tmp14 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
46
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
47
+ r0_index = r0_offset + r0_base
48
+ r0_mask = r0_index < r0_numel
49
+ roffset = r0_offset
50
+ rindex = r0_index
51
+ r0_1 = r0_index
52
+ tmp10 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
53
+ tmp11 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
54
+ tmp24 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
55
+ tmp12 = tmp10 * tmp11
56
+ tmp13 = tmp12.to(tl.float32)
57
+ tmp15 = tmp13 * tmp14
58
+ tmp16 = -0.5
59
+ tmp17 = tmp8 * tmp16
60
+ tmp18 = tmp14 * tmp14
61
+ tmp19 = tmp18 * tmp14
62
+ tmp20 = tmp17 * tmp19
63
+ tmp21 = ks0
64
+ tmp22 = tmp21.to(tl.float32)
65
+ tmp23 = (tmp20 / tmp22)
66
+ tmp25 = tmp24.to(tl.float32)
67
+ tmp26 = 2.0
68
+ tmp27 = tmp25 * tmp26
69
+ tmp28 = tmp23 * tmp27
70
+ tmp29 = tmp15 + tmp28
71
+ tmp30 = tmp29.to(tl.float32)
72
+ tl.store(out_ptr1 + (r0_1 + ks0*x0), tmp30, r0_mask & xmask)
SpecForge-ext/cache/compiled_kernels/gn/cgnsrigp6qu2lbqq76g27kshvt2bzkyjnupza5ds7znhjxrnwhif.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 256, 'r0_': 16},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 16
20
+ R0_BLOCK: tl.constexpr = 16
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_2 = r0_index
32
+ x0 = (xindex % ks0)
33
+ x1 = xindex // ks0
34
+ x3 = xindex
35
+ tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0)
36
+ tmp1 = r0_2
37
+ tmp2 = tmp1.to(tl.int16)
38
+ tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
39
+ tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
40
+ tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True)
41
+ tmp7 = tmp0.to(tl.int64)
42
+ tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
43
+ tmp10 = tl.where(xmask, tmp8, 0)
44
+ tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64)
45
+ tmp12 = tmp6.to(tl.int64)
46
+ tmp13 = tmp12.to(tl.int32)
47
+ tmp14 = tmp11.to(tl.int32)
48
+ tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask)
49
+ tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask)
SpecForge-ext/cache/compiled_kernels/gv/cgva67py5joafltlxqsoz5uf2a7qh2rakl35e3wsc4nbdlv75anq.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1
101
+
102
+ ZQ = 2
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = ks0
106
+ ZKV = 2
107
+ KV_LEN = ks1
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 2
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = ks2
148
+ stride_kv_idx_h = ks3*ks4
149
+ stride_kv_idx_m = ks4
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks5
245
+ stride_q_idx_h = ks6*ks7
246
+ stride_q_idx_n = ks6
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = ks0
385
+ KV_LEN = ks1
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = ks8
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order to since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = False
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = ks0
596
+ KV_LEN = ks1
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = False
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
682
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = ks8
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
SpecForge-ext/cache/compiled_kernels/gv/cgvbha5mvyldninvrzu5qgbcoz6irvhuphtcgrde6mr733uggxnb.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['5_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/fi/cfiplsvt2q6tbvsfjtg2dd47g7npdwtvk5m3lv4anjbxwgjigkj2.py
38
+ # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum]
39
+ # Source node to ATen node mapping:
40
+ # and_2 => bitwise_and_1
41
+ # and_3 => bitwise_and_2
42
+ # and_4 => bitwise_and_3, view_8
43
+ # b => iota
44
+ # batched_outputs_2 => view_9
45
+ # causal_mask => ge, view
46
+ # diagnol_mask => eq
47
+ # index => index
48
+ # index_1 => index_1
49
+ # index_2 => index_2
50
+ # lt => lt, view_1
51
+ # lt_1 => lt_1, view_2
52
+ # m => iota_2
53
+ # mask_2 => view_10
54
+ # mask_3 => permute
55
+ # mask_block_sum => sum_1
56
+ # n => iota_3
57
+ # padding_mask => bitwise_and, view_3, view_4
58
+ # padding_mask_1 => lt_2, view_6
59
+ # remainder => remainder
60
+ # remainder_1 => remainder_1
61
+ # result_1 => bitwise_or, full_default
62
+ # result_2 => bitwise_or_1
63
+ # sub => sub, view_7
64
+ # suffix_mask => ge_1
65
+ # Graph fragment:
66
+ # %arg0_1 : Tensor "i64[8][1]cuda:3" = PlaceHolder[target=arg0_1]
67
+ # %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:3, pin_memory: False})
68
+ # %iota_2 : Tensor "i64[2048][1]cuda:3"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False})
69
+ # %view : Tensor "i64[2048, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {})
70
+ # %iota_3 : Tensor "i64[2048][1]cuda:3"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False})
71
+ # %ge : Tensor "b8[2048, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {})
72
+ # %iota : Tensor "i64[8][1]cuda:3"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False})
73
+ # %index : Tensor "i64[8][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {})
74
+ # %view_1 : Tensor "i64[8, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {})
75
+ # %lt : Tensor "b8[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {})
76
+ # %view_4 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, 2048]), kwargs = {})
77
+ # %index_1 : Tensor "i64[8][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {})
78
+ # %view_2 : Tensor "i64[8, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {})
79
+ # %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {})
80
+ # %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {})
81
+ # %bitwise_and : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {})
82
+ # %bitwise_and_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {})
83
+ # %bitwise_or : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {})
84
+ # %ge_1 : Tensor "b8[2048][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {})
85
+ # %remainder : Tensor "i64[2048][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {})
86
+ # %index_2 : Tensor "i64[8][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {})
87
+ # %view_6 : Tensor "i64[8, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {})
88
+ # %lt_2 : Tensor "b8[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {})
89
+ # %bitwise_and_2 : Tensor "b8[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_2), kwargs = {})
90
+ # %view_8 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, 2048]), kwargs = {})
91
+ # %view_7 : Tensor "i64[2048, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {})
92
+ # %sub : Tensor "i64[2048, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {})
93
+ # %remainder_1 : Tensor "i64[2048, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub, 2048), kwargs = {})
94
+ # %eq : Tensor "b8[2048, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {})
95
+ # %bitwise_and_3 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq), kwargs = {})
96
+ # %bitwise_or_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {})
97
+ # %view_9 : Tensor "b8[8, 1, 2048, 2048][4194304, 4194304, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, 2048]), kwargs = {})
98
+ # %view_10 : Tensor "b8[8, 1, 16, 128, 16, 128][4194304, 4194304, 262144, 2048, 128, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [8, 1, 16, 128, 16, 128]), kwargs = {})
99
+ # %permute : Tensor "b8[8, 1, 16, 16, 128, 128][4194304, 4194304, 262144, 128, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {})
100
+ # %sum_1 : Tensor "i64[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {})
101
+ # return %sum_1
102
+ triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', '''
103
+ import triton
104
+ import triton.language as tl
105
+
106
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
107
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
108
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
109
+ triton_helpers.set_driver_to_gpu()
110
+
111
+ @triton_heuristics.reduction(
112
+ size_hints={'x': 2048, 'r0_': 16384},
113
+ reduction_hint=ReductionHint.INNER,
114
+ filename=__file__,
115
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
116
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}}
117
+ )
118
+ @triton.jit
119
+ def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
120
+ xnumel = 2048
121
+ r0_numel = 16384
122
+ rnumel = r0_numel
123
+ RBLOCK: tl.constexpr = R0_BLOCK
124
+ xoffset = tl.program_id(0) * XBLOCK
125
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
126
+ xmask = xindex < xnumel
127
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
128
+ rbase = r0_base
129
+ x1 = ((xindex // 16) % 16)
130
+ x0 = (xindex % 16)
131
+ x2 = xindex // 256
132
+ tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
133
+ _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
134
+ x6 = xindex
135
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
136
+ r0_index = r0_offset + r0_base
137
+ r0_mask = r0_index < r0_numel
138
+ roffset = r0_offset
139
+ rindex = r0_index
140
+ r0_4 = r0_index // 128
141
+ r0_3 = (r0_index % 128)
142
+ tmp0 = r0_4 + 128*x1
143
+ tmp1 = r0_3 + 128*x0
144
+ tmp2 = tmp0 >= tmp1
145
+ tmp4 = tmp1 < tmp3
146
+ tmp5 = tmp0 < tmp3
147
+ tmp6 = tmp4 & tmp5
148
+ tmp7 = tmp2 & tmp6
149
+ tmp8 = tl.full([1, 1], False, tl.int1)
150
+ tmp9 = tmp8 | tmp7
151
+ tmp10 = tl.full([1, 1], 2048, tl.int64)
152
+ tmp11 = tmp1 >= tmp10
153
+ tmp12 = tmp11 & tmp4
154
+ tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
155
+ tmp14 = (tmp13 % tmp10)
156
+ tmp15 = tl.full([1, 1], 0, tl.int32)
157
+ tmp16 = tmp14 != tmp15
158
+ tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
159
+ tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0
160
+ tmp19 = tmp17 != tmp18
161
+ tmp20 = tmp16 & tmp19
162
+ tmp21 = tmp14 + tmp10
163
+ tmp22 = tl.where(tmp20, tmp21, tmp14)
164
+ tmp23 = tl.full([1, 1], 0, tl.int64)
165
+ tmp24 = tmp22 == tmp23
166
+ tmp25 = tmp12 & tmp24
167
+ tmp26 = tmp9 | tmp25
168
+ tmp27 = tmp26.to(tl.int64)
169
+ tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK])
170
+ tmp30 = _tmp29 + tmp28
171
+ _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29)
172
+ tmp29 = tl.sum(_tmp29, 1)[:, None]
173
+ tl.store(out_ptr0 + (x6), tmp29, xmask)
174
+ ''', device_str='cuda')
175
+
176
+
177
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ks/ckske6cm4vgoewu6hpzmhdk7yxnddtnqlrbts7nwodsrty3grim2.py
178
+ # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros]
179
+ # Source node to ATen node mapping:
180
+ # dense_mask_4 => full_default_4
181
+ # Graph fragment:
182
+ # %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3, pin_memory: False})
183
+ # return %index_put_1
184
+ triton_poi_fused_new_zeros_1 = async_compile.triton('triton_poi_fused_new_zeros_1', '''
185
+ import triton
186
+ import triton.language as tl
187
+
188
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
189
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
190
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
191
+ triton_helpers.set_driver_to_gpu()
192
+
193
+ @triton_heuristics.pointwise(
194
+ size_hints={'x': 4096},
195
+ filename=__file__,
196
+ triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
197
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}},
198
+ min_elem_per_thread=0
199
+ )
200
+ @triton.jit
201
+ def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr):
202
+ xnumel = 2176
203
+ xoffset = tl.program_id(0) * XBLOCK
204
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
205
+ xmask = xindex < xnumel
206
+ x0 = xindex
207
+ tmp0 = tl.full([1], 0, tl.int32)
208
+ tl.store(out_ptr0 + (x0), tmp0, xmask)
209
+ ''', device_str='cuda')
210
+
211
+
212
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ub/cubczabrf2ptryq2athnmru2byipbgarioa3462puxv2jwv6vm4c.py
213
+ # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
214
+ # Source node to ATen node mapping:
215
+ # arange_4 => iota_4
216
+ # arange_6 => iota_8
217
+ # child_3 => convert_element_type_3
218
+ # child_4 => convert_element_type_4
219
+ # child_7 => convert_element_type_6
220
+ # child_8 => convert_element_type_7
221
+ # col_indices => sort
222
+ # col_indices_1 => sort_1
223
+ # col_range => iota_5
224
+ # col_range_1 => iota_9
225
+ # dense_mask => convert_element_type_2
226
+ # dense_mask_1 => convert_element_type_5
227
+ # dense_mask_2 => full_default_1
228
+ # dense_mask_4 => full_default_4
229
+ # full_blocks => eq_1
230
+ # full_blocks_1 => convert_element_type_1
231
+ # gt => gt
232
+ # index_mask => lt_4
233
+ # index_mask_1 => lt_5
234
+ # lt_3 => lt_3
235
+ # num_blocks_in_row => sum_2
236
+ # num_blocks_in_row_1 => sum_3
237
+ # partial_blocks => bitwise_and_4
238
+ # partial_blocks_1 => convert_element_type
239
+ # row_indices => unsqueeze
240
+ # row_indices_1 => unsqueeze_7
241
+ # setitem => full_default_3, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6
242
+ # setitem_1 => full_default_6, index_put_1, iota_10, iota_11, unsqueeze_10, unsqueeze_11, unsqueeze_12, unsqueeze_13, unsqueeze_9
243
+ # unsqueeze_1 => unsqueeze_1
244
+ # unsqueeze_3 => unsqueeze_8
245
+ # valid_indices => full_default_2, where
246
+ # valid_indices_1 => full_default_5, where_1
247
+ # Graph fragment:
248
+ # %sum_1 : Tensor "i64[8, 1, 16, 16][256, 2048, 16, 1]cuda:3" = PlaceHolder[target=sum_1]
249
+ # %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:3" = PlaceHolder[target=sum_2]
250
+ # %sum_3 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:3" = PlaceHolder[target=sum_3]
251
+ # %buf2 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:3" = PlaceHolder[target=buf2]
252
+ # %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=convert_element_type_3]
253
+ # %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=convert_element_type_4]
254
+ # %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:3" = PlaceHolder[target=index_put]
255
+ # %buf4 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:3" = PlaceHolder[target=buf4]
256
+ # %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=convert_element_type_6]
257
+ # %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=convert_element_type_7]
258
+ # %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:3" = PlaceHolder[target=index_put_1]
259
+ # %gt : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
260
+ # %lt_3 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {})
261
+ # %bitwise_and_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {})
262
+ # %convert_element_type : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {})
263
+ # %convert_element_type_2 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {})
264
+ # %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_2,), kwargs = {stable: True, descending: True})
265
+ # %eq_1 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {})
266
+ # %convert_element_type_1 : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_1, torch.int8), kwargs = {})
267
+ # %convert_element_type_5 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {})
268
+ # %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_5,), kwargs = {stable: True, descending: True})
269
+ # %full_default_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3, pin_memory: False})
270
+ # %iota_7 : Tensor "i64[8][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False})
271
+ # %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {})
272
+ # %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {})
273
+ # %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {})
274
+ # %iota_6 : Tensor "i64[1][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False})
275
+ # %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {})
276
+ # %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {})
277
+ # %iota_4 : Tensor "i32[16][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:3, requires_grad: False})
278
+ # %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {})
279
+ # %iota_5 : Tensor "i32[16][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:3, requires_grad: False})
280
+ # %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {})
281
+ # %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {})
282
+ # %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {})
283
+ # %lt_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {})
284
+ # %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {})
285
+ # %full_default_2 : Tensor "i32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3, pin_memory: False})
286
+ # %where : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %full_default_2), kwargs = {})
287
+ # %full_default_3 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3, pin_memory: False})
288
+ # %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_3), kwargs = {})
289
+ # %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3, pin_memory: False})
290
+ # %iota_11 : Tensor "i64[8][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False})
291
+ # %unsqueeze_11 : Tensor "i64[8, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_11, -1), kwargs = {})
292
+ # %unsqueeze_12 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_11, -1), kwargs = {})
293
+ # %unsqueeze_13 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_12, -1), kwargs = {})
294
+ # %iota_10 : Tensor "i64[1][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False})
295
+ # %unsqueeze_9 : Tensor "i64[1, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_10, -1), kwargs = {})
296
+ # %unsqueeze_10 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_9, -1), kwargs = {})
297
+ # %iota_8 : Tensor "i32[16][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:3, requires_grad: False})
298
+ # %unsqueeze_7 : Tensor "i32[16, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_8, -1), kwargs = {})
299
+ # %iota_9 : Tensor "i32[16][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:3, requires_grad: False})
300
+ # %sum_3 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_5, [-1]), kwargs = {})
301
+ # %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, torch.int32), kwargs = {})
302
+ # %unsqueeze_8 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_6, 3), kwargs = {})
303
+ # %lt_5 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_9, %unsqueeze_8), kwargs = {})
304
+ # %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_3, torch.int32), kwargs = {})
305
+ # %full_default_5 : Tensor "i32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3, pin_memory: False})
306
+ # %where_1 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_5, %convert_element_type_7, %full_default_5), kwargs = {})
307
+ # %full_default_6 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3, pin_memory: False})
308
+ # %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_4, [%unsqueeze_13, %unsqueeze_10, %unsqueeze_7, %where_1], %full_default_6), kwargs = {})
309
+ # return %buf2,%buf4,%sum_2,%sum_3,%convert_element_type_3,%convert_element_type_6,%convert_element_type_4,%buf9,%convert_element_type_7,%buf16
310
+ triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 = async_compile.triton('triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', '''
311
+ import triton
312
+ import triton.language as tl
313
+
314
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
315
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
316
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
317
+ triton_helpers.set_driver_to_gpu()
318
+
319
+ @triton_heuristics.persistent_reduction(
320
+ size_hints={'x': 128, 'r0_': 16},
321
+ reduction_hint=ReductionHint.DEFAULT,
322
+ filename=__file__,
323
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
324
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
325
+ )
326
+ @triton.jit
327
+ def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr):
328
+ xnumel = 128
329
+ r0_numel = 16
330
+ R0_BLOCK: tl.constexpr = 16
331
+ rnumel = r0_numel
332
+ RBLOCK: tl.constexpr = R0_BLOCK
333
+ xoffset = tl.program_id(0) * XBLOCK
334
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
335
+ xmask = xindex < xnumel
336
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
337
+ r0_offset = 0
338
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
339
+ roffset = r0_offset
340
+ rindex = r0_index
341
+ r0_1 = r0_index
342
+ x0 = xindex
343
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0)
344
+ tmp1 = tl.full([1, 1], 0, tl.int64)
345
+ tmp2 = tmp0 > tmp1
346
+ tmp3 = tl.full([1, 1], 16384, tl.int64)
347
+ tmp4 = tmp0 < tmp3
348
+ tmp5 = tmp2 & tmp4
349
+ tmp6 = tmp5.to(tl.int8)
350
+ tmp7 = tmp6.to(tl.int32)
351
+ tmp8 = r0_1
352
+ tmp9 = tmp8.to(tl.int16)
353
+ tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
354
+ tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
355
+ tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True)
356
+ tmp14 = tmp0 == tmp3
357
+ tmp15 = tmp14.to(tl.int8)
358
+ tmp16 = tmp15.to(tl.int32)
359
+ tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
360
+ tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True)
361
+ tmp20 = tmp7.to(tl.int64)
362
+ tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK])
363
+ tmp23 = tl.where(xmask, tmp21, 0)
364
+ tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64)
365
+ tmp25 = tmp16.to(tl.int64)
366
+ tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
367
+ tmp28 = tl.where(xmask, tmp26, 0)
368
+ tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64)
369
+ tmp30 = tmp24.to(tl.int32)
370
+ tmp31 = tmp29.to(tl.int32)
371
+ tmp32 = tmp13.to(tl.int64)
372
+ tmp33 = tmp32.to(tl.int32)
373
+ tmp34 = tmp8 < tmp30
374
+ tmp35 = tl.full([1, 1], 16, tl.int32)
375
+ tmp36 = tl.where(tmp34, tmp33, tmp35)
376
+ tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32)
377
+ tmp38 = tmp36 + tmp37
378
+ tmp39 = tmp36 < 0
379
+ tmp40 = tl.where(tmp39, tmp38, tmp36)
380
+ tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17")
381
+ tmp42 = tl.full([1, 1], 1, tl.int32)
382
+ tmp43 = tmp19.to(tl.int64)
383
+ tmp44 = tmp43.to(tl.int32)
384
+ tmp45 = tmp8 < tmp31
385
+ tmp46 = tl.where(tmp45, tmp44, tmp35)
386
+ tmp47 = tmp46 + tmp37
387
+ tmp48 = tmp46 < 0
388
+ tmp49 = tl.where(tmp48, tmp47, tmp46)
389
+ tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17")
390
+ tl.store(out_ptr4 + (x0), tmp30, xmask)
391
+ tl.store(out_ptr5 + (x0), tmp31, xmask)
392
+ tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask)
393
+ tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
394
+ tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask)
395
+ tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
396
+ ''', device_str='cuda')
397
+
398
+
399
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/j2/cj2n7wy6fqxxmwdo7cbwmboubeyyxlahapaguf5qhoapiz75gjze.py
400
+ # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
401
+ # Source node to ATen node mapping:
402
+ # batched_outputs_3 => clone_4, slice_2
403
+ # col_indices_2 => sort_2
404
+ # num_blocks_in_row_2 => sum_4
405
+ # q_indices => clone_6, convert_element_type_9
406
+ # q_num_blocks => convert_element_type_8
407
+ # transpose => permute_1
408
+ # Graph fragment:
409
+ # %buf9 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:3" = PlaceHolder[target=buf9]
410
+ # %buf11 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:3" = PlaceHolder[target=buf11]
411
+ # %sum_4 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:3" = PlaceHolder[target=sum_4]
412
+ # %slice_2 : Tensor "i32[8, 1, 16, 16][272, 272, 17, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, 16), kwargs = {})
413
+ # %clone_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format})
414
+ # %permute_1 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {})
415
+ # %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True})
416
+ # %convert_element_type_9 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {})
417
+ # %clone_6 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format})
418
+ # %sum_4 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {})
419
+ # %convert_element_type_8 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {})
420
+ # return %buf11,%sum_4,%clone_6,%convert_element_type_8
421
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', '''
422
+ import triton
423
+ import triton.language as tl
424
+
425
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
426
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
427
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
428
+ triton_helpers.set_driver_to_gpu()
429
+
430
+ @triton_heuristics.persistent_reduction(
431
+ size_hints={'x': 128, 'r0_': 16},
432
+ reduction_hint=ReductionHint.DEFAULT,
433
+ filename=__file__,
434
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
435
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}}
436
+ )
437
+ @triton.jit
438
+ def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr):
439
+ xnumel = 128
440
+ r0_numel = 16
441
+ R0_BLOCK: tl.constexpr = 16
442
+ rnumel = r0_numel
443
+ RBLOCK: tl.constexpr = R0_BLOCK
444
+ xoffset = tl.program_id(0) * XBLOCK
445
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
446
+ xmask = xindex < xnumel
447
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
448
+ r0_offset = 0
449
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
450
+ roffset = r0_offset
451
+ rindex = r0_index
452
+ r0_2 = r0_index
453
+ x0 = (xindex % 16)
454
+ x1 = xindex // 16
455
+ x3 = xindex
456
+ tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0)
457
+ tmp1 = r0_2
458
+ tmp2 = tmp1.to(tl.int16)
459
+ tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
460
+ tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
461
+ tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True)
462
+ tmp7 = tmp0.to(tl.int64)
463
+ tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
464
+ tmp10 = tl.where(xmask, tmp8, 0)
465
+ tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64)
466
+ tmp12 = tmp6.to(tl.int64)
467
+ tmp13 = tmp12.to(tl.int32)
468
+ tmp14 = tmp11.to(tl.int32)
469
+ tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask)
470
+ tl.store(out_ptr3 + (x3), tmp14, xmask)
471
+ ''', device_str='cuda')
472
+
473
+
474
+ async_compile.wait(globals())
475
+ del async_compile
476
+
477
+ class Runner:
478
+ def __init__(self, partitions):
479
+ self.partitions = partitions
480
+
481
+ def recursively_apply_fns(self, fns):
482
+ new_callables = []
483
+ for fn, c in zip(fns, self.partitions):
484
+ new_callables.append(fn(c))
485
+ self.partitions = new_callables
486
+
487
+ def call(self, args):
488
+ arg0_1, = args
489
+ args.clear()
490
+ assert_size_stride(arg0_1, (8, ), (1, ))
491
+ with torch.cuda._DeviceGuard(3):
492
+ torch.cuda.set_device(3)
493
+ buf0 = empty_strided_cuda((8, 1, 16, 16), (256, 2048, 16, 1), torch.int64)
494
+ # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum]
495
+ stream3 = get_raw_stream(3)
496
+ triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0.run(arg0_1, buf0, 2048, 16384, stream=stream3)
497
+ del arg0_1
498
+ buf15 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32)
499
+ # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros]
500
+ stream3 = get_raw_stream(3)
501
+ triton_poi_fused_new_zeros_1.run(buf15, 2176, stream=stream3)
502
+ buf8 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32)
503
+ # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros]
504
+ stream3 = get_raw_stream(3)
505
+ triton_poi_fused_new_zeros_1.run(buf8, 2176, stream=stream3)
506
+ buf6 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
507
+ buf13 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
508
+ buf7 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32)
509
+ buf14 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32)
510
+ # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
511
+ stream3 = get_raw_stream(3)
512
+ triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.run(buf0, buf6, buf13, buf7, buf8, buf14, buf15, 128, 16, stream=stream3)
513
+ del buf0
514
+ buf22 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32)
515
+ buf24 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
516
+ # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
517
+ stream3 = get_raw_stream(3)
518
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf8, buf22, buf24, 128, 16, stream=stream3)
519
+ del buf8
520
+ buf19 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32)
521
+ buf21 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
522
+ # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
523
+ stream3 = get_raw_stream(3)
524
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf15, buf19, buf21, 128, 16, stream=stream3)
525
+ del buf15
526
+ return (buf19, buf21, buf22, buf24, buf14, buf13, buf7, buf6, )
527
+
528
+ runner = Runner(partitions=[])
529
+ call = runner.call
530
+ recursively_apply_fns = runner.recursively_apply_fns
531
+
532
+
533
+ def benchmark_compiled_module(times=10, repeat=10):
534
+ from torch._dynamo.testing import rand_strided
535
+ from torch._inductor.utils import print_performance
536
+ arg0_1 = rand_strided((8, ), (1, ), device='cuda:3', dtype=torch.int64)
537
+ fn = lambda: call([arg0_1])
538
+ return print_performance(fn, times=times, repeat=repeat)
539
+
540
+
541
+ if __name__ == "__main__":
542
+ from torch._inductor.wrapper_benchmark import compiled_module_main
543
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/gy/94796f3e1399aa6e798adba6b896031b3152400abd45f5ee80e2ec3df79f0b97.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 22, "triton_cache_hash": "XRR2QXTZQK4DSBTDJUTNXO6FEFXI2IIRKSC5GYSBWLTL56SKI4WA"}