Lekr0 committed (verified)
Commit 7cd1cbc · 1 parent: c7d7aff

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. SpecForge-ext/cache/compiled_kernels/2a/c2aenxafaj3vioqyzq7mx27etpwqzasypu2acikotkgg3rec7mlw.py +47 -0
  2. SpecForge-ext/cache/compiled_kernels/2c/c2cmsqbgkrofzfikzrnehvhp4wxhze4bly4ct5edlg3syiny626e.py +43 -0
  3. SpecForge-ext/cache/compiled_kernels/2k/c2kz55grrshpc3qkvg6jesbu63ts5wlhkwtjukm7zkcvhkilgn76.py +52 -0
  4. SpecForge-ext/cache/compiled_kernels/2v/c2vabblrjzyryauc2jram5kwgwvjexq53bdwxugagjegc2xvufuy.py +44 -0
  5. SpecForge-ext/cache/compiled_kernels/2v/c2vbm66z3map72ysgiduadjtps3nnrhjldngw5bzue3cm5xo44w5.py +835 -0
  6. SpecForge-ext/cache/compiled_kernels/35/c35huqp6ngzh67kt32kuxoqpghc32fstv4zogcouzabdxxwta3sl.py +27 -0
  7. SpecForge-ext/cache/compiled_kernels/35/db497030eed19cbbd19ee623329ec09ad9ad496274b305a5f803696a7ce87fc1.best_config +1 -0
  8. SpecForge-ext/cache/compiled_kernels/45/c45n6n4rv3f3q66fpyq53ugyel2jmywhufx7ogpqvuyls4hiicz2.py +1065 -0
  9. SpecForge-ext/cache/compiled_kernels/45/c45wcuv4sn2ie6lowss5cksehnjgehlinebmvyopum4so5p257dk.py +835 -0
  10. SpecForge-ext/cache/compiled_kernels/4h/9e00c3cbd4f3ffea506c2d972effa4f5d1a03b1819fbd2068ce6d04ad21a37d7.best_config +1 -0
  11. SpecForge-ext/cache/compiled_kernels/4h/c4hrpftpfto2n4yelfxmq5tawsfst2z5xq7othxvdoymqaudsvcw.py +56 -0
  12. SpecForge-ext/cache/compiled_kernels/4k/30a0e09dbdf44769796e9e261da2a9dcbfc798ae7811e19f9adc033f960f3fae.best_config +1 -0
  13. SpecForge-ext/cache/compiled_kernels/4k/c4kzcehfveyvvtlnmx5jh5naezqnmtz2ubxuawsucb27r43j5yfa.py +49 -0
  14. SpecForge-ext/cache/compiled_kernels/4m/c4mv34wib446qhr7sd5yhgc4mdneb7isnb6uitnbwvdgrbpgyf2s.py +552 -0
  15. SpecForge-ext/cache/compiled_kernels/4u/c4uf4o6eypfpqr4isgii4opqr5i3brobwecljte7sqvztk2kyafz.py +552 -0
  16. SpecForge-ext/cache/compiled_kernels/4u/c4uhrh7gjsy72in52pmmkpoiwetwjbked3nkrbcotbo4sj5bq7bi.py +835 -0
  17. SpecForge-ext/cache/compiled_kernels/4x/c4xykt7eysbenti5r55drq4w7k6c7fih4ifrou2alyqcn6r5enon.py +835 -0
  18. SpecForge-ext/cache/compiled_kernels/54/c5464ptly4n22voq77yo3wrltmxhbase2ojnypkgcpcxg6js4oty.py +46 -0
  19. SpecForge-ext/cache/compiled_kernels/54/c54p5bozrk7z3jkhpl6meytxfu7bz7ojmkijrdgczbq55oalwpgl.py +552 -0
  20. SpecForge-ext/cache/compiled_kernels/5h/c5h6tol66uk77tfumu3xd25ecbr6kkxkqgk3zbmjpk4tc6sikmjb.py +543 -0
  21. SpecForge-ext/cache/compiled_kernels/5p/c5pbkg5eq64emuv25ukki7a5dxvn2p2sh6jeiwb6b54tbidps5w7.py +552 -0
  22. SpecForge-ext/cache/compiled_kernels/5s/c5siycmobmba5rqczjfbtd45di6el6qnpizugzs3hsg4jzkcqnpk.py +161 -0
  23. SpecForge-ext/cache/compiled_kernels/5u/235c5fbee66a14cc3d65896905ec816ec90c51ba6594c4a627960306977eb07c.best_config +1 -0
  24. SpecForge-ext/cache/compiled_kernels/6b/c6beknosybos5d54llineldguuueh3kpjlkiuzm4pkorx7g6mjh6.py +45 -0
  25. SpecForge-ext/cache/compiled_kernels/6b/c6bpf3ctcqs5wvcac26go3fcp5hdc2pxduwgba2cnxt52xqmp6mq.py +334 -0
  26. SpecForge-ext/cache/compiled_kernels/6j/b801eb968d13baeef00c09ffebb7c203c75661545f70c7ec4ed906e946ad8a67.best_config +1 -0
  27. SpecForge-ext/cache/compiled_kernels/6j/c6jx5fvfijye7zqqg42xonpcdfuwatv7bizrwompd5o3dua57uju.py +24 -0
  28. SpecForge-ext/cache/compiled_kernels/6o/c6obqatzdeyb7elxstetxuvmlhbvwph6buxkixqs4flvdn2x6vgl.py +835 -0
  29. SpecForge-ext/cache/compiled_kernels/7g/59ff39d5526de7bb833fbd386ca3ce564bdaf6828f559a423e599b5ad90d0456.best_config +1 -0
  30. SpecForge-ext/cache/compiled_kernels/7m/c7mmadjna7dltm72lxvsoktdadnw2jtxufsj2eoflefh2r5jo4gq.py +24 -0
  31. SpecForge-ext/cache/compiled_kernels/7m/e130479b4d145e755b390ab3b709dd817d1548c0596f91391e7581de8609a9eb.best_config +1 -0
  32. SpecForge-ext/cache/compiled_kernels/7o/c7oiol3zozs5oktlpjhg3lu46rhbgu3bqq6yibefmn2imo6bua5k.py +48 -0
  33. SpecForge-ext/cache/compiled_kernels/7z/c7z2jbjub3aupgnechol65vkvi5ruwpylzosdbqvscdyxmreb3jy.py +86 -0
  34. SpecForge-ext/cache/compiled_kernels/7z/c7z6kbhlhnd55iz3suxpzcfjhjv7p7i2zelu2nitjoegrwczbdyf.py +52 -0
  35. SpecForge-ext/cache/compiled_kernels/7z/de291d239bdb6c33244f90904700e0423d0a8026bdcf04c4cb1f87b0edee041b.best_config +1 -0
  36. SpecForge-ext/cache/compiled_kernels/aa/caa67m6yhgzsw5semsgkn3vvui6pjb2e2mxtfb5xyoo3c5qle6ao.py +320 -0
  37. SpecForge-ext/cache/compiled_kernels/aa/caabkjzbaqm7hrv3ypoalyjx45pdt7jezorxxk75d4cahg2knncu.py +89 -0
  38. SpecForge-ext/cache/compiled_kernels/af/cafe3dsuelcloemwu5jdikp7lqano5qxv7iayhtm5xgji2xvr4k6.py +47 -0
  39. SpecForge-ext/cache/compiled_kernels/ai/caivmpnbt7ve3qybkm6k756igdxn3ykevul35fdg4vvgknrmprqo.py +66 -0
  40. SpecForge-ext/cache/compiled_kernels/ai/f2f38be4dfdf6b1c14c068f88a04203cd9a67c3fc07629f341d6212e60d2f52e.best_config +1 -0
  41. SpecForge-ext/cache/compiled_kernels/al/25feb68bb70a2d653884ed092be99a324d74e7c4fa2b0800c70b0c5cede23a82.best_config +1 -0
  42. SpecForge-ext/cache/compiled_kernels/al/cal2r4tfyw6gic3ggqyud3nufnajx6xau2koieoitx6zg4wsiozm.py +56 -0
  43. SpecForge-ext/cache/compiled_kernels/aq/caqqpjwqelw7hv6k6nwpxjuod3tfnwg62cypxwyuozfme2ykuybp.py +307 -0
  44. SpecForge-ext/cache/compiled_kernels/aq/caqvrlb25w5an4txp3dstxcj6tqlcc4mprakf75e5sbtbuzd254g.py +711 -0
  45. SpecForge-ext/cache/compiled_kernels/at/2dfb5ffb77d217b8298333b84d6362971879c20614915aac57601c1f150ac07b.best_config +1 -0
  46. SpecForge-ext/cache/compiled_kernels/at/cat6f3b7vbc3opxxrqwtgyrnap7msqfa5gw45bly56fm7xfzsng7.py +27 -0
  47. SpecForge-ext/cache/compiled_kernels/at/catnwworbo47zz5uux2qx6gtvq5zrkdmzm5qpt64msmr3cjlnoz5.py +675 -0
  48. SpecForge-ext/cache/compiled_kernels/av/cavp7xan77tfr7qytfkp6sjrgkd6hvruiaqfzkeibtl5rtagscng.py +99 -0
  49. SpecForge-ext/cache/compiled_kernels/bd/cbdpymknkquuerovirx6corahubfs5khfhys2add2b3c2zkuvlup.py +835 -0
  50. SpecForge-ext/cache/compiled_kernels/bi/8786fd641e91216a3bc7781055fbc9277e1637f9f319eaed8124e438ba94886f.best_config +1 -0
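Context note (not part of the commit itself): the `.py` files under `cache/compiled_kernels/` import from `torch._inductor.runtime`, so they appear to be Triton kernels generated and cached by TorchInductor, with the `.best_config` files holding the corresponding autotuning results. A minimal sketch of how such a cache is typically populated is shown below; it assumes a standard `torch.compile` workflow and the `TORCHINDUCTOR_CACHE_DIR` environment variable, and the toy function and paths are illustrative only, not the actual SpecForge-ext training code.

# Hypothetical sketch: compile a small function so Inductor emits and caches
# Triton kernels (argmax/sum-style reductions similar in spirit to the files
# in this commit). The cache directory path is assumed, not taken from the repo.
import os
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "SpecForge-ext/cache/compiled_kernels"  # assumed layout

import torch

def toy_step(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # argmax over the vocab dim, compare to targets, sum the matches
    pred = logits.argmax(dim=-1)
    return (pred == target).to(torch.int64).sum()

compiled_step = torch.compile(toy_step)  # first call triggers Inductor codegen + caching

if torch.cuda.is_available():
    logits = torch.randn(2, 128, 32000, device="cuda", dtype=torch.bfloat16)
    target = torch.randint(0, 32000, (2, 128), device="cuda")
    print(compiled_step(logits, target))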
SpecForge-ext/cache/compiled_kernels/2a/c2aenxafaj3vioqyzq7mx27etpwqzasypu2acikotkgg3rec7mlw.py ADDED
@@ -0,0 +1,47 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 4096, 'r0_': 32768},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 32000
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = (xindex % ks0)
    x1 = xindex // ks0
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
    x3 = xindex
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
            _tmp2, _tmp2_index, tmp1, rindex
        )
        _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
        _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
    tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
    tmp2 = tmp2_idx[:, None]
    tl.store(out_ptr0 + (x3), tmp2, xmask)
SpecForge-ext/cache/compiled_kernels/2c/c2cmsqbgkrofzfikzrnehvhp4wxhze4bly4ct5edlg3syiny626e.py ADDED
@@ -0,0 +1,43 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 32, 'r0_': 16},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tmp0.to(tl.int64)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
        tmp4 = _tmp3 + tmp2
        _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3)
    tmp3 = tl.sum(_tmp3, 1)[:, None]
    x2 = (xindex % ks1)
    x3 = xindex // ks1
    tmp5 = tmp3.to(tl.int32)
    tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask)
SpecForge-ext/cache/compiled_kernels/2k/c2kz55grrshpc3qkvg6jesbu63ts5wlhkwtjukm7zkcvhkilgn76.py ADDED
@@ -0,0 +1,52 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 131072, 'r0_': 128},
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x1 = xindex // ks0
    x0 = (xindex % ks0)
    _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    x3 = xindex
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32)
        tmp1 = ks1*ks2
        tmp2 = tmp0 < tmp1
        tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0)
        tmp7 = tmp5 * tmp6
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp3 * tmp8
        tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
        tmp11 = tl.where(tmp2, tmp9, tmp10)
        tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
        tmp14 = _tmp13 + tmp12
        _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13)
    tmp13 = tl.sum(_tmp13, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp13, xmask)
SpecForge-ext/cache/compiled_kernels/2v/c2vabblrjzyryauc2jram5kwgwvjexq53bdwxugagjegc2xvufuy.py ADDED
@@ -0,0 +1,44 @@

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 4096},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
        tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp0 == tmp1
        tmp3 = tmp2.to(tl.int64)
        tmp5 = tmp3 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tl.where(r0_mask, tmp8, _tmp7)
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None)
SpecForge-ext/cache/compiled_kernels/2v/c2vbm66z3map72ysgiduadjtps3nnrhjldngw5bzue3cm5xo44w5.py ADDED
@@ -0,0 +1,835 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1
101
+
102
+ ZQ = 8
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = ks0
106
+ ZKV = 8
107
+ KV_LEN = ks1
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 8
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = ks2
148
+ stride_kv_idx_h = ks3*ks4
149
+ stride_kv_idx_m = ks4
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks5
245
+ stride_q_idx_h = ks6*ks7
246
+ stride_q_idx_n = ks6
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = ks0
385
+ KV_LEN = ks1
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order to since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = ks8
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order to since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = False
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = ks0
596
+ KV_LEN = ks1
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = False
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
682
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = ks8
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
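    # Worked example (illustrative only, with assumed inputs; not part of the generated kernel):
    # with SPARSE_BLOCK=128, BLOCK=64 (so SPARSE_BLOCK_MULTIPLE=2) and col_indices=[0, 3, ...]:
    # loop_iter=0 stays inside sparse block 0, so the returned offset is BLOCK=64;
    # loop_iter=1 finishes sparse block 0, so needs_jump is true and the offset becomes
    # (3 - 0) * 128 - (2 - 1) * 64 = 320, advancing the pointers from position 64 to
    # position 384, i.e. the start of sparse block 3.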
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
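load_checked_2d above selects one of four load paths depending on which dimensions are known to be in-bounds, zero-filling anything past M_LEN or N_LEN. A minimal NumPy sketch of that masking behavior (illustrative only, not part of the cached kernel; the array and helper names are made up for the example):

import numpy as np

def load_checked_2d_ref(buf, offs_m, offs_n, m_len, n_len):
    # Positions outside [0, m_len) x [0, n_len) come back as 0.0,
    # mirroring the mask=..., other=0.0 arguments in the tl.load calls above.
    mask = (offs_m[:, None] < m_len) & (offs_n[None, :] < n_len)
    idx_m = np.minimum(offs_m, m_len - 1)[:, None]
    idx_n = np.minimum(offs_n, n_len - 1)[None, :]
    return np.where(mask, buf[idx_m, idx_n], 0.0)

buf = np.arange(12, dtype=np.float32).reshape(3, 4)
print(load_checked_2d_ref(buf, np.arange(4), np.arange(4), 3, 4))  # last row comes back as zeros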
SpecForge-ext/cache/compiled_kernels/35/c35huqp6ngzh67kt32kuxoqpghc32fstv4zogcouzabdxxwta3sl.py ADDED
@@ -0,0 +1,27 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 512},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x0 = (xindex % ks0)
23
+ x1 = ((xindex // ks0) % ks1)
24
+ x2 = xindex // ks2
25
+ tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last')
26
+ tmp1 = tmp0.to(tl.int32)
27
+ tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask)
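The pointwise kernel above casts an int64 buffer to int32 while permuting the first two logical dimensions; the x0/x1/x2 expressions are a flat-index decomposition. A plain-Python sketch of that decomposition (illustrative only; the sizes are assumptions, and ks2 is taken to play the role of ks0*ks1):

def decompose(xindex, ks0, ks1):
    # Mirrors x0, x1, x2 in the kernel above (with ks2 presumed to equal ks0 * ks1).
    x0 = xindex % ks0
    x1 = (xindex // ks0) % ks1
    x2 = xindex // (ks0 * ks1)
    return x0, x1, x2

# Example with assumed sizes ks0=4, ks1=3:
print(decompose(17, 4, 3))  # -> (1, 1, 1): element (x2=1, x1=1, x0=1) of a [*, 3, 4] layout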
SpecForge-ext/cache/compiled_kernels/35/db497030eed19cbbd19ee623329ec09ad9ad496274b305a5f803696a7ce87fc1.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "IK5RT3JGLTF5PMMUH32NIWB2GXNU6R6CGIZSCRHU3I65YM226KDA"}
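Each .best_config file like the one above records the launch configuration the Inductor autotuner settled on for its sibling kernel (here XBLOCK=128 with 4 warps and 1 stage). A minimal sketch for inspecting such a file (the filename below is a placeholder, not a real path in this repo):

import json

with open("some_kernel.best_config") as f:  # placeholder path
    cfg = json.load(f)
print(cfg["XBLOCK"], cfg["num_warps"], cfg["num_stages"])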
SpecForge-ext/cache/compiled_kernels/45/c45n6n4rv3f3q66fpyq53ugyel2jmywhufx7ogpqvuyls4hiicz2.py ADDED
@@ -0,0 +1,1065 @@
1
+ # AOT ID: ['9_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/c4/cc4r2l3x4dfli5iih5dji2abfxoclfozqdaqfbdxtcf6lqfpqwdo.py
38
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
39
+ # Source node to ATen node mapping:
40
+ # Graph fragment:
41
+ # %getitem : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem]
42
+ # %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:3" = PlaceHolder[target=tangents_1]
43
+ # %buf0 : Tensor "bf16[8, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf0]
44
+ # %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False})
45
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {})
46
+ # return %buf0,%buf1
47
+ triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', '''
48
+ import triton
49
+ import triton.language as tl
50
+
51
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
52
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
53
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
54
+ triton_helpers.set_driver_to_gpu()
55
+
56
+ @triton_heuristics.reduction(
57
+ size_hints={'x': 524288, 'r0_': 128},
58
+ reduction_hint=ReductionHint.DEFAULT,
59
+ filename=__file__,
60
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
61
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}}
62
+ )
63
+ @triton.jit
64
+ def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
65
+ xnumel = 524288
66
+ r0_numel = 128
67
+ rnumel = r0_numel
68
+ RBLOCK: tl.constexpr = R0_BLOCK
69
+ xoffset = tl.program_id(0) * XBLOCK
70
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
71
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
72
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
73
+ rbase = r0_base
74
+ x0 = (xindex % 2048)
75
+ x1 = ((xindex // 2048) % 32)
76
+ x2 = xindex // 65536
77
+ x4 = xindex
78
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
79
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
80
+ r0_index = r0_offset + r0_base
81
+ r0_mask = r0_index < r0_numel
82
+ roffset = r0_offset
83
+ rindex = r0_index
84
+ r0_3 = r0_index
85
+ tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
86
+ tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
87
+ tmp2 = tmp0 * tmp1
88
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
89
+ tmp5 = _tmp4 + tmp3
90
+ _tmp4 = tl.where(r0_mask, tmp5, _tmp4)
91
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
92
+ tmp6 = tmp4.to(tl.float32)
93
+ tmp7 = 0.0
94
+ tmp8 = tmp6 - tmp7
95
+ tl.store(out_ptr1 + (x4), tmp8, None)
96
+ ''', device_str='cuda')
97
+
98
+
99
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/kx/ckxiuwld5taodt6aogxkojllbqa6rvgdkesruwe5ssurjxs2lpmw.py
100
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
101
+ # Source node to ATen node mapping:
102
+ # Graph fragment:
103
+ # %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_1]
104
+ # %primals_3 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_3]
105
+ # %primals_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_5]
106
+ # %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=getitem_1]
107
+ # %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf1]
108
+ # %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:3" = PlaceHolder[target=tangents_1]
109
+ # %getitem_3 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem_3]
110
+ # %getitem_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=getitem_5]
111
+ # %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_9]
112
+ # %primals_7 : Tensor "i32[8, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:3" = PlaceHolder[target=primals_7]
113
+ # %primals_15 : Tensor "i32[8, 1, s56][s56, s56, 1]cuda:3" = PlaceHolder[target=primals_15]
114
+ # %primals_17 : Tensor "i32[8, 1, s84, 16][16*s84, 16*s84, 16, 1]cuda:3" = PlaceHolder[target=primals_17]
115
+ # %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_11]
116
+ # %primals_13 : Tensor "i32[8, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:3" = PlaceHolder[target=primals_13]
117
+ # %primals_19 : Tensor "i32[8, 1, s99][s99, s99, 1]cuda:3" = PlaceHolder[target=primals_19]
118
+ # %primals_21 : Tensor "i32[8, 1, s6, 16][16*s6, 16*s6, 16, 1]cuda:3" = PlaceHolder[target=primals_21]
119
+ # %primals_10 : Tensor "i64[8][1]cuda:3" = PlaceHolder[target=primals_10]
120
+ # %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False})
121
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {})
122
+ # return %getitem_4
123
+ triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', '''
124
+ import triton
125
+ import triton.language as tl
126
+
127
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
128
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
129
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
130
+
131
+ @triton_heuristics.template(
132
+
133
+ num_stages=3,
134
+ num_warps=8,
135
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
136
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
137
+
138
+ )
139
+ @triton.jit
140
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3):
141
+ PRESCALE_QK : tl.constexpr = False
142
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
143
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
144
+ WRITE_DQ : tl.constexpr = True
145
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
146
+ OUTPUT_MAX : tl.constexpr = False
147
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
148
+ IS_DIVISIBLE : tl.constexpr = False
149
+ SM_SCALE : tl.constexpr = 0.08838834764831843
150
+ GQA_SHARED_HEADS : tl.constexpr = 4
151
+ HAS_FULL_BLOCKS : tl.constexpr = True
152
+ QK_HEAD_DIM : tl.constexpr = 128
153
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
154
+ V_HEAD_DIM : tl.constexpr = 128
155
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
156
+ SAFE_HEAD_DIM : tl.constexpr = True
157
+ BLOCK_M1 : tl.constexpr = 64
158
+ BLOCK_N1 : tl.constexpr = 128
159
+ BLOCK_M2 : tl.constexpr = 128
160
+ BLOCK_N2 : tl.constexpr = 64
161
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
162
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
163
+ INDEX_DTYPE : tl.constexpr = tl.int32
164
+ Q = arg_Q
165
+ K = arg_K
166
+ V = arg_V
167
+ LSE = arg_LSE
168
+ DELTA = arg_DELTA
169
+ DO = arg_DO
170
+ DQ = arg_DQ
171
+ DV = arg_DV
172
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
173
+ KV_IDX = arg_KV_IDX
174
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
175
+ Q_IDX = arg_Q_IDX
176
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
177
+ FULL_KV_IDX = arg_FULL_KV_IDX
178
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
179
+ FULL_Q_IDX = arg_FULL_Q_IDX
180
+
181
+ # Sub notation for this kernel:
182
+ #
183
+ # Q: Query, K: Key, V: Value
184
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
185
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
186
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
187
+ # DK: Derivative of Key, which is written to via the store_output call due to some limitations with
188
+ # inductor codegen
189
+ # M: Number of queries, N: Number of keys/values
190
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
191
+ # V_HEAD_DIM: The dimension of the value embeddings
192
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
193
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
194
+ # (Modifiable) Performance tuning options
195
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
196
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
197
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
198
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
199
+ #
200
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
201
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
202
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
203
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
204
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
205
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
206
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
207
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
208
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
209
+
210
+ # The below are kernel options that can be applied for certain score_mods,
211
+ # or involve a numerics vs. perf tradeoff
212
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
213
+ # about 20% more numerical error, but slightly faster.
214
+
215
+ # Define strides of inputs
216
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
217
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1
218
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1
219
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
220
+
221
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
222
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1
223
+
224
+ ZQ = 8
225
+ HQ = 32
226
+ HKV = 8
227
+ Q_LEN = 2048
228
+ ZKV = 8
229
+ KV_LEN = ks0
230
+
231
+ MATMUL_PRECISION = Q.dtype.element_ty
232
+
233
+ pid = tl.program_id(0).to(INDEX_DTYPE)
234
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
235
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
236
+
237
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
238
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
239
+ off_zkv = off_zq % ZKV # kv batch idx
240
+
241
+ SPARSE_Z = 8
242
+ SPARSE_HQ = 1
243
+
244
+ sparse_idx_z = off_zq % SPARSE_Z
245
+
246
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
247
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
248
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
249
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
250
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
251
+
252
+ # offset K, V, DV pointers for batch/kv-head
253
+ K += k_adj
254
+ V += v_adj
255
+ DV += dv_adj
256
+
257
+ RCP_LN2 = 1.44269504
258
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
259
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
260
+
261
+ if pid >= NUM_KV_BLOCKS:
262
+ off_pid = pid - NUM_KV_BLOCKS
263
+ # THIS BLOCK DOES DQ
264
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
265
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
266
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
267
+ start_m2_block = off_pid % NUM_Q_BLOCKS
268
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
269
+ stride_kv_num_blks_h = 16
270
+ stride_kv_idx_h = 16*ks1
271
+ stride_kv_idx_m = ks1
272
+
273
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
274
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
275
+
276
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
277
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
278
+
279
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
280
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
281
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
282
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
283
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
284
+
285
+ Q2 = Q + q_adj2
286
+ DO2 = DO + do_adj2
287
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
288
+ # if Q is broadcasted)
289
+ DQ2 = DQ + dq_adj2
290
+ LSE2 = LSE + off_chz2
291
+ DELTA2 = DELTA + off_chz2
292
+
293
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
294
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
295
+
296
+ start_m2 = start_m2_block * BLOCK_M2
297
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
298
+
299
+ # load Q and do: they stay in SRAM throughout the inner loop.
300
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
301
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
302
+
303
+ if PRESCALE_QK:
304
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
305
+
306
+ if IS_DIVISIBLE:
307
+ Di = tl.load(DELTA2 + offs_m2)
308
+ lse = tl.load(LSE2 + offs_m2)
309
+ else:
310
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
311
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
312
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
313
+ lse = lse[:, None]
314
+
315
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
316
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
317
+ kv_indices = KV_IDX + sparse_kv_idx_offset
318
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
319
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
320
+
321
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
322
+ dq = bwd_dq_inner(
323
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
324
+ K, V,
325
+ dq, q, do, Di, lse,
326
+ off_zq, off_hq2, offs_m2, offs_n2,
327
+ stride_kn, stride_kd, stride_vn, stride_vd,
328
+ kv_indices, sparse_kv_num_blocks,
329
+ MATMUL_PRECISION,
330
+ IS_FULL_BLOCKS=False,
331
+ )
332
+
333
+ if HAS_FULL_BLOCKS:
334
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
335
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
336
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
337
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
338
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
339
+
340
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
341
+ dq = bwd_dq_inner(
342
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
343
+ K, V,
344
+ dq, q, do, Di, lse,
345
+ off_zq, off_hq2, offs_m2, offs_n2,
346
+ stride_kn, stride_kd, stride_vn, stride_vd,
347
+ kv_indices, sparse_kv_num_blocks,
348
+ MATMUL_PRECISION,
349
+ IS_FULL_BLOCKS=True,
350
+ )
351
+
352
+ # Write back dQ.
353
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
354
+ dq *= SM_SCALE
355
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
356
+ tl.store(dq_ptrs, dq)
357
+ else:
358
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
359
+ else:
360
+ # THIS BLOCK DOES DK & DV
361
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
362
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
363
+
364
+ pid_mask = pid // SPARSE_KV_MULTIPLE
365
+
366
+ stride_q_num_blks_h = ks2
367
+ stride_q_idx_h = 16*ks3
368
+ stride_q_idx_n = 16
369
+
370
+
371
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
372
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
373
+
374
+ start_n1 = pid * BLOCK_N1
375
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
376
+
377
+ # load K and V: they stay in SRAM throughout the inner loop.
378
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
379
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
380
+
381
+ if PRESCALE_QK:
382
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
383
+
384
+ for off_g in range(0, GQA_SHARED_HEADS):
385
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
386
+
387
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
388
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
389
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
390
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
391
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
392
+
393
+ Q1 = Q + q_adj1
394
+ DO1 = DO + do_adj1
395
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
396
+ # if Q is broadcasted)
397
+ LSE1 = LSE + off_chz1
398
+ DELTA1 = DELTA + off_chz1
399
+
400
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
401
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
402
+
403
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
404
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
405
+
406
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
407
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
408
+ q_indices = Q_IDX + sparse_q_idx_offset
409
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
410
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
411
+
412
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
413
+ dk, dv = bwd_dkdv_inner(
414
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
415
+ Q1, DO1, DELTA1, LSE1,
416
+ dk, dv, k, v,
417
+ off_zq, off_hq1, offs_n1, offs_m1,
418
+ stride_qm, stride_qd, stride_dom, stride_dod,
419
+ q_indices, sparse_q_num_blocks,
420
+ MATMUL_PRECISION,
421
+ IS_FULL_BLOCKS=False,
422
+ )
423
+
424
+
425
+ if HAS_FULL_BLOCKS:
426
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
427
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
428
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
429
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
430
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
431
+
432
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
433
+ dk, dv = bwd_dkdv_inner(
434
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
435
+ Q1, DO1, DELTA1, LSE1,
436
+ dk, dv, k, v,
437
+ off_zq, off_hq1, offs_n1, offs_m1,
438
+ stride_qm, stride_qd, stride_dom, stride_dod,
439
+ q_indices, sparse_q_num_blocks,
440
+ MATMUL_PRECISION,
441
+ IS_FULL_BLOCKS=True,
442
+ )
443
+
444
+ # Write back dV and dK.
445
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
446
+
447
+ index_n = offs_n1[:, None]
448
+ index_k = offs_k[None, :]
449
+ index_v = offs_v[None, :]
450
+
451
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
452
+ tl.store(dv_ptrs, dv)
453
+ else:
454
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
455
+
456
+ dk *= SM_SCALE
457
+
458
+ if SAFE_HEAD_DIM:
459
+ mask = index_n < KV_LEN
460
+ else:
461
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
462
+
463
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
464
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
465
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
466
+ xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0
467
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
468
+
469
+ @triton.jit
470
+ def bwd_dq_inner(
471
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
472
+ K, V, # pointers
473
+ dq, q, do, Di, lse,
474
+ off_z, off_hq, offs_m2, offs_n2,
475
+ stride_kn, stride_kd, stride_vn, stride_vd,
476
+ kv_indices, sparse_kv_num_blocks,
477
+ MATMUL_PRECISION,
478
+ IS_FULL_BLOCKS,
479
+ ):
480
+ PRESCALE_QK : tl.constexpr = False
481
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
482
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
483
+ WRITE_DQ : tl.constexpr = True
484
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
485
+ OUTPUT_MAX : tl.constexpr = False
486
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
487
+ IS_DIVISIBLE : tl.constexpr = False
488
+ SM_SCALE : tl.constexpr = 0.08838834764831843
489
+ GQA_SHARED_HEADS : tl.constexpr = 4
490
+ HAS_FULL_BLOCKS : tl.constexpr = True
491
+ QK_HEAD_DIM : tl.constexpr = 128
492
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
493
+ V_HEAD_DIM : tl.constexpr = 128
494
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
495
+ SAFE_HEAD_DIM : tl.constexpr = True
496
+ BLOCK_M1 : tl.constexpr = 64
497
+ BLOCK_N1 : tl.constexpr = 128
498
+ BLOCK_M2 : tl.constexpr = 128
499
+ BLOCK_N2 : tl.constexpr = 64
500
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
501
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
502
+ INDEX_DTYPE : tl.constexpr = tl.int32
503
+
504
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
505
+ RCP_LN2: tl.constexpr = 1.44269504
506
+ Q_LEN = 2048
507
+ KV_LEN = ks0
508
+
509
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
510
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
511
+
512
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
513
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
514
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
515
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
516
+
517
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
518
+
519
+ for start_n in range(0, hi):
520
+ dq = bwd_dq_block_mn(
521
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
522
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
523
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
524
+ stride_kn, stride_kd, stride_vn, stride_vd,
525
+ kv_indices, sparse_kv_num_blocks,
526
+ MATMUL_PRECISION, RCP_LN2,
527
+ IS_FULL_BLOCKS,
528
+ )
529
+
530
+ # Increment pointers.
531
+ offset = get_offset_for_next_block(
532
+ start_n, kv_indices, sparse_kv_num_blocks,
533
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
534
+ )
535
+
536
+ kT_ptrs += offset * stride_kn
537
+ vT_ptrs += offset * stride_vn
538
+
539
+ offs_n2 += offset
540
+
541
+ return dq
542
+
543
+
544
+ @triton.jit
545
+ def bwd_dq_block_mn(
546
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
547
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
548
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
549
+ stride_kn, stride_kd, stride_vn, stride_vd,
550
+ kv_indices, sparse_kv_num_blocks,
551
+ MATMUL_PRECISION, RCP_LN2,
552
+ IS_FULL_BLOCKS,
553
+ ):
554
+ PRESCALE_QK : tl.constexpr = False
555
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
556
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
557
+ WRITE_DQ : tl.constexpr = True
558
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
559
+ OUTPUT_MAX : tl.constexpr = False
560
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
561
+ IS_DIVISIBLE : tl.constexpr = False
562
+ SM_SCALE : tl.constexpr = 0.08838834764831843
563
+ GQA_SHARED_HEADS : tl.constexpr = 4
564
+ HAS_FULL_BLOCKS : tl.constexpr = True
565
+ QK_HEAD_DIM : tl.constexpr = 128
566
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
567
+ V_HEAD_DIM : tl.constexpr = 128
568
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
569
+ SAFE_HEAD_DIM : tl.constexpr = True
570
+ BLOCK_M1 : tl.constexpr = 64
571
+ BLOCK_N1 : tl.constexpr = 128
572
+ BLOCK_M2 : tl.constexpr = 128
573
+ BLOCK_N2 : tl.constexpr = 64
574
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
575
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
576
+ INDEX_DTYPE : tl.constexpr = tl.int32
577
+
578
+
579
+ # NB reversed order since K is transposed
580
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
581
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
582
+ if not PRESCALE_QK:
583
+ qk *= SM_SCALE
584
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
585
+ pre_mod_scores = qk
586
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
587
+ # The boundary check is done for the outer loop, but here, since we're iterating across the N dim, it's possible
588
+ # that the m indices read out of bounds for the PIDs spanning the Q_LEN boundary
589
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
590
+
591
+ tmp0 = (qk)
592
+ post_mod_scores = tmp0
593
+
594
+
595
+
596
+
597
+ if not IS_DIVISIBLE:
598
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
599
+
600
+ if not IS_FULL_BLOCKS:
601
+ tmp1 = tl.full([1], False, tl.int1)
602
+ tmp2 = (m)
603
+ tmp3 = (n)
604
+ tmp4 = tmp2 >= tmp3
605
+ tmp5 = tmp3.to(tl.int64)
606
+ tmp6 = (off_z)
607
+ tmp7 = tl.load(in_ptr16 + tmp6)
608
+ tmp8 = tmp5 < tmp7
609
+ tmp9 = tmp2.to(tl.int64)
610
+ tmp10 = tmp9 < tmp7
611
+ tmp11 = tmp8 & tmp10
612
+ tmp12 = tmp4 & tmp11
613
+ tmp13 = tmp1 | tmp12
614
+ tmp14 = tl.full([1], 2048, tl.int32)
615
+ tmp15 = tmp3 >= tmp14
616
+ tmp16 = (tmp3 % tmp14)
617
+ tmp17 = tl.full([1], 0, tl.int32)
618
+ tmp18 = tmp16 != tmp17
619
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
620
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
621
+ tmp21 = tmp19 != tmp20
622
+ tmp22 = tmp18 & tmp21
623
+ tmp23 = tmp16 + tmp14
624
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
625
+ tmp25 = tmp24.to(tl.int64)
626
+ tmp26 = tmp25 < tmp7
627
+ tmp27 = tmp15 & tmp26
628
+ tmp28 = tmp3 - tmp2
629
+ tmp29 = (tmp28 % tmp14)
630
+ tmp30 = tmp29 != tmp17
631
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
632
+ tmp32 = tmp31 != tmp20
633
+ tmp33 = tmp30 & tmp32
634
+ tmp34 = tmp29 + tmp14
635
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
636
+ tmp36 = tmp35 == tmp17
637
+ tmp37 = tmp27 & tmp36
638
+ tmp38 = tmp13 | tmp37
639
+ mask_mod_output = tmp38
640
+
641
+
642
+ # apply mask for partial masked block
643
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
644
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
645
+ if not PRESCALE_QK:
646
+ post_mod_scores *= RCP_LN2
647
+ p = tl.math.exp2(post_mod_scores - lse)
648
+ # Compute dP and dS.
649
+ # NB reversed order since V is transposed
650
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
651
+
652
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
653
+ ds = p * (dp - Di[:, None])
654
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
655
+ tmp39 = (ds)
656
+ grad_scores = tmp39
657
+
658
+
659
+ if not IS_DIVISIBLE:
660
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
661
+
662
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
663
+ if WRITE_DQ:
664
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
665
+
666
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
667
+ ds = grad_scores
668
+
669
+ if not IS_FULL_BLOCKS:
670
+ # (grads) apply mask for partially unmasked block
671
+ ds = tl.where(mask_mod_output, ds, 0.0)
672
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
673
+ ds = ds.to(MATMUL_PRECISION)
674
+ # Compute dQ.
675
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
676
+
677
+ return dq
678
+
679
+
680
+ @triton.jit
681
+ def bwd_dkdv_inner(
682
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
683
+ Q, DO, DELTA, LSE, # pointers
684
+ dk, dv, k, v,
685
+ off_z, off_hq, offs_n1, offs_m1,
686
+ stride_qm, stride_qd, stride_dom, stride_dod,
687
+ q_indices, sparse_q_num_blocks,
688
+ MATMUL_PRECISION,
689
+ IS_FULL_BLOCKS,
690
+ ):
691
+ PRESCALE_QK : tl.constexpr = False
692
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
693
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
694
+ WRITE_DQ : tl.constexpr = True
695
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
696
+ OUTPUT_MAX : tl.constexpr = False
697
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
698
+ IS_DIVISIBLE : tl.constexpr = False
699
+ SM_SCALE : tl.constexpr = 0.08838834764831843
700
+ GQA_SHARED_HEADS : tl.constexpr = 4
701
+ HAS_FULL_BLOCKS : tl.constexpr = True
702
+ QK_HEAD_DIM : tl.constexpr = 128
703
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
704
+ V_HEAD_DIM : tl.constexpr = 128
705
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
706
+ SAFE_HEAD_DIM : tl.constexpr = True
707
+ BLOCK_M1 : tl.constexpr = 64
708
+ BLOCK_N1 : tl.constexpr = 128
709
+ BLOCK_M2 : tl.constexpr = 128
710
+ BLOCK_N2 : tl.constexpr = 64
711
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
712
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
713
+ INDEX_DTYPE : tl.constexpr = tl.int32
714
+
715
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
716
+ RCP_LN2: tl.constexpr = 1.44269504
717
+ Q_LEN = 2048
718
+ KV_LEN = ks0
719
+
720
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
721
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
722
+
723
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
724
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
725
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
726
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
727
+
728
+ # The minimum is needed to handle the case where we run with a super large
729
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
730
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
731
+
732
+ for start_m in range(0, hi):
733
+ dk, dv = bwd_dkdv_block_mn(
734
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
735
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
736
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
737
+ stride_qm, stride_qd, stride_dom, stride_dod,
738
+ q_indices, sparse_q_num_blocks,
739
+ MATMUL_PRECISION, RCP_LN2,
740
+ IS_FULL_BLOCKS,
741
+ )
742
+ # Increment pointers.
743
+ offset = get_offset_for_next_block(
744
+ start_m, q_indices, sparse_q_num_blocks,
745
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
746
+ )
747
+
748
+ qT_ptrs += offset * stride_qm
749
+ do_ptrs += offset * stride_dom
750
+ offs_m1 += offset
751
+
752
+ return dk, dv
753
+
754
+
755
+ @triton.jit
756
+ def bwd_dkdv_block_mn(
757
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
758
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
759
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
760
+ stride_qm, stride_qd, stride_dom, stride_dod,
761
+ q_indices, sparse_q_num_blocks,
762
+ MATMUL_PRECISION, RCP_LN2,
763
+ IS_FULL_BLOCKS,
764
+ ):
765
+ PRESCALE_QK : tl.constexpr = False
766
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
767
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
768
+ WRITE_DQ : tl.constexpr = True
769
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
770
+ OUTPUT_MAX : tl.constexpr = False
771
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
772
+ IS_DIVISIBLE : tl.constexpr = False
773
+ SM_SCALE : tl.constexpr = 0.08838834764831843
774
+ GQA_SHARED_HEADS : tl.constexpr = 4
775
+ HAS_FULL_BLOCKS : tl.constexpr = True
776
+ QK_HEAD_DIM : tl.constexpr = 128
777
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
778
+ V_HEAD_DIM : tl.constexpr = 128
779
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
780
+ SAFE_HEAD_DIM : tl.constexpr = True
781
+ BLOCK_M1 : tl.constexpr = 64
782
+ BLOCK_N1 : tl.constexpr = 128
783
+ BLOCK_M2 : tl.constexpr = 128
784
+ BLOCK_N2 : tl.constexpr = 64
785
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
786
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
787
+ INDEX_DTYPE : tl.constexpr = tl.int32
788
+
789
+
790
+ # NB reversed order since Q is transposed
791
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
792
+ # Load LSE before computing qk to reduce pipeline stall.
793
+ if IS_DIVISIBLE:
794
+ lse = tl.load(LSE + offs_m1)
795
+ else:
796
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
797
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
798
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
799
+ if not PRESCALE_QK:
800
+ qkT *= SM_SCALE
801
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
802
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
803
+ # The boundary check is done for the outer loop, but here, since we're iterating across the M dim, it's possible
804
+ # that the n indices read out of bounds for the PIDs spanning the KV_LEN boundary
805
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
806
+
807
+ pre_mod_scores = qkT
808
+ tmp40 = (qkT)
809
+ post_mod_scores = tmp40
810
+
811
+
812
+
813
+ if not IS_DIVISIBLE:
814
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
815
+
816
+ if not IS_FULL_BLOCKS:
817
+ tmp41 = tl.full([1], False, tl.int1)
818
+ tmp42 = (m)
819
+ tmp43 = (n)
820
+ tmp44 = tmp42 >= tmp43
821
+ tmp45 = tmp43.to(tl.int64)
822
+ tmp46 = (off_z)
823
+ tmp47 = tl.load(in_ptr16 + tmp46)
824
+ tmp48 = tmp45 < tmp47
825
+ tmp49 = tmp42.to(tl.int64)
826
+ tmp50 = tmp49 < tmp47
827
+ tmp51 = tmp48 & tmp50
828
+ tmp52 = tmp44 & tmp51
829
+ tmp53 = tmp41 | tmp52
830
+ tmp54 = tl.full([1], 2048, tl.int32)
831
+ tmp55 = tmp43 >= tmp54
832
+ tmp56 = (tmp43 % tmp54)
833
+ tmp57 = tl.full([1], 0, tl.int32)
834
+ tmp58 = tmp56 != tmp57
835
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
836
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
837
+ tmp61 = tmp59 != tmp60
838
+ tmp62 = tmp58 & tmp61
839
+ tmp63 = tmp56 + tmp54
840
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
841
+ tmp65 = tmp64.to(tl.int64)
842
+ tmp66 = tmp65 < tmp47
843
+ tmp67 = tmp55 & tmp66
844
+ tmp68 = tmp43 - tmp42
845
+ tmp69 = (tmp68 % tmp54)
846
+ tmp70 = tmp69 != tmp57
847
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
848
+ tmp72 = tmp71 != tmp60
849
+ tmp73 = tmp70 & tmp72
850
+ tmp74 = tmp69 + tmp54
851
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
852
+ tmp76 = tmp75 == tmp57
853
+ tmp77 = tmp67 & tmp76
854
+ tmp78 = tmp53 | tmp77
855
+ mask_mod_output = tmp78
856
+
857
+ # (grads) apply mask for fully masked block
858
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
859
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
860
+ if not PRESCALE_QK:
861
+ post_mod_scores *= RCP_LN2
862
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
863
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
864
+ # Compute dV.
865
+ ppT = pT
866
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
867
+ if IS_DIVISIBLE:
868
+ Di = tl.load(DELTA + offs_m1)
869
+ else:
870
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
871
+ # Compute dP and dS.
872
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
873
+ dsT = pT * (dpT - Di[None, :])
874
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
875
+ tmp79 = (dsT)
876
+ grad_scores = tmp79
877
+
878
+
879
+
880
+ if not IS_DIVISIBLE:
881
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
882
+
883
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
884
+ if not WRITE_DQ:
885
+ idx_b = off_z
886
+ idx_h = off_hq
887
+ idx_m = m
888
+ idx_n = n
889
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
890
+
891
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
892
+ dsT = grad_scores
893
+ if not IS_FULL_BLOCKS:
894
+ # (grads) apply mask for partially unmasked block
895
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
896
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
897
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
898
+
899
+ return dk, dv
900
+
901
+ # Utility triton funcs
902
+ @triton.jit
903
+ def get_offset_for_next_block(
904
+ loop_iter, col_indices, total_blocks,
905
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
906
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
907
+ ):
908
+ if BLOCKS_ARE_CONTIGUOUS:
909
+ return BLOCK
910
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
911
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
912
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
913
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
914
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
915
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
916
+ return offset
917
+
918
+ @triton.jit
919
+ def get_bounded_indices(indices, max_len=None):
920
+ return indices % max_len if max_len is not None else indices
921
+
922
+ @triton.jit
923
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
924
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
925
+ return tl.load(block_ptr)
926
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
927
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
928
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
929
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
930
+ else:
931
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
932
+
933
+ @triton.jit
934
+ def load_checked_2d(
935
+ ptr,
936
+ offs_m,
937
+ offs_n,
938
+ stride_m,
939
+ stride_n,
940
+ IS_DIVISIBLE_M: tl.constexpr,
941
+ IS_DIVISIBLE_N: tl.constexpr,
942
+ M_LEN: tl.constexpr,
943
+ N_LEN: tl.constexpr,
944
+ ):
945
+ # Calculate final pointer if strides are provided
946
+ if stride_m is not None and stride_n is not None:
947
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
948
+
949
+ # Handle all masking cases
950
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
951
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
952
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
953
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
954
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
955
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
956
+ else: # Both divisible
957
+ return tl.load(ptr)
958
+ ''', device_str='cuda')
959
+
960
+
961
+ async_compile.wait(globals())
962
+ del async_compile
963
+
964
+ class Runner:
965
+ def __init__(self, partitions):
966
+ self.partitions = partitions
967
+
968
+ def recursively_apply_fns(self, fns):
969
+ new_callables = []
970
+ for fn, c in zip(fns, self.partitions):
971
+ new_callables.append(fn(c))
972
+ self.partitions = new_callables
973
+
974
+ def call(self, args):
975
+ primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1 = args
976
+ args.clear()
977
+ s0 = primals_8
978
+ s72 = primals_6
979
+ s4 = primals_12
980
+ s56 = primals_14
981
+ s84 = primals_16
982
+ s99 = primals_18
983
+ s6 = primals_20
984
+ assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1))
985
+ assert_size_stride(primals_3, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
986
+ assert_size_stride(primals_5, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
987
+ assert_size_stride(primals_7, (8, 1, 16, s72), (16*s72, 16*s72, s72, 1))
988
+ assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1))
989
+ assert_size_stride(primals_10, (8, ), (1, ))
990
+ assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1))
991
+ assert_size_stride(primals_13, (8, 1, 16, s4), (16*s4, 16*s4, s4, 1))
992
+ assert_size_stride(primals_15, (8, 1, s56), (s56, s56, 1))
993
+ assert_size_stride(primals_17, (8, 1, s84, 16), (16*s84, 16*s84, 16, 1))
994
+ assert_size_stride(primals_19, (8, 1, s99), (s99, s99, 1))
995
+ assert_size_stride(primals_21, (8, 1, s6, 16), (16*s6, 16*s6, 16, 1))
996
+ assert_size_stride(getitem, (8, 32, 2048, 128), (8388608, 128, 4096, 1))
997
+ assert_size_stride(getitem_1, (8, 32, 2048), (65536, 2048, 1))
998
+ assert_size_stride(tangents_1, (8, 32, 2048, 128), (8388608, 262144, 128, 1))
999
+ with torch.cuda._DeviceGuard(3):
1000
+ torch.cuda.set_device(3)
1001
+ buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32)
1002
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
1003
+ stream3 = get_raw_stream(3)
1004
+ triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 524288, 128, stream=stream3)
1005
+ del getitem
1006
+ buf3 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16)
1007
+ buf4 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16)
1008
+ buf5 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16)
1009
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
1010
+ stream3 = get_raw_stream(3)
1011
+ triton_tem_fused_zeros_1.run(primals_1, primals_3, primals_5, getitem_1, buf1, tangents_1, buf3, buf4, primals_9, primals_7, primals_15, primals_17, primals_11, primals_13, primals_19, primals_21, primals_10, buf5, s0, s72, s56, s84, 64 + ((127 + s0) // 128), 8, 8, stream=stream3)
1012
+ del buf1
1013
+ del getitem_1
1014
+ del primals_1
1015
+ del primals_10
1016
+ del primals_11
1017
+ del primals_13
1018
+ del primals_15
1019
+ del primals_17
1020
+ del primals_19
1021
+ del primals_21
1022
+ del primals_3
1023
+ del primals_5
1024
+ del primals_7
1025
+ del primals_9
1026
+ del tangents_1
1027
+ return (buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, )
1028
+
1029
+ runner = Runner(partitions=[])
1030
+ call = runner.call
1031
+ recursively_apply_fns = runner.recursively_apply_fns
1032
+
1033
+
1034
+ def benchmark_compiled_module(times=10, repeat=10):
1035
+ from torch._dynamo.testing import rand_strided
1036
+ from torch._inductor.utils import print_performance
1037
+ primals_8 = 4096
1038
+ primals_6 = 32
1039
+ primals_12 = 32
1040
+ primals_14 = 32
1041
+ primals_16 = 32
1042
+ primals_18 = 32
1043
+ primals_20 = 32
1044
+ primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16)
1045
+ primals_3 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:3', dtype=torch.bfloat16)
1046
+ primals_5 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:3', dtype=torch.bfloat16)
1047
+ primals_7 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:3', dtype=torch.int32)
1048
+ primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32)
1049
+ primals_10 = rand_strided((8, ), (1, ), device='cuda:3', dtype=torch.int64)
1050
+ primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32)
1051
+ primals_13 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:3', dtype=torch.int32)
1052
+ primals_15 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:3', dtype=torch.int32)
1053
+ primals_17 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:3', dtype=torch.int32)
1054
+ primals_19 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:3', dtype=torch.int32)
1055
+ primals_21 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:3', dtype=torch.int32)
1056
+ getitem = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16)
1057
+ getitem_1 = rand_strided((8, 32, 2048), (65536, 2048, 1), device='cuda:3', dtype=torch.float32)
1058
+ tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16)
1059
+ fn = lambda: call([primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1])
1060
+ return print_performance(fn, times=times, repeat=repeat)
1061
+
1062
+
1063
+ if __name__ == "__main__":
1064
+ from torch._inductor.wrapper_benchmark import compiled_module_main
1065
+ compiled_module_main('None', benchmark_compiled_module)
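The module above closes with Inductor's standard benchmark harness; the Triton helpers it compiles (get_bounded_indices, load_checked_block, load_checked_2d) exist so that tiles which overrun the tensor extents can still be read safely. A rough eager-mode sketch of the masking that load_checked_2d performs (plain PyTorch, for illustration only; the buffer and tile sizes below are made up, not taken from the module):

    import torch

    def load_checked_2d_ref(buf, offs_m, offs_n, stride_m, stride_n, m_len, n_len):
        # Gather an (M, N) tile from a flat buffer, zero-filling every position
        # whose row or column index falls outside [0, m_len) x [0, n_len),
        # mirroring the masked tl.load paths in the generated helper.
        idx = offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
        mask = (offs_m[:, None] < m_len) & (offs_n[None, :] < n_len)
        out = torch.zeros(mask.shape, dtype=buf.dtype)
        out[mask] = buf[idx[mask]]
        return out

    # Hypothetical 5x7 row-major buffer read as a 4x8 tile: the eighth column
    # overruns n_len, so it comes back zero-padded instead of faulting.
    buf = torch.arange(35, dtype=torch.float32)
    tile = load_checked_2d_ref(buf, torch.arange(4), torch.arange(8), 7, 1, 5, 7)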
SpecForge-ext/cache/compiled_kernels/45/c45wcuv4sn2ie6lowss5cksehnjgehlinebmvyopum4so5p257dk.py ADDED
@@ -0,0 +1,835 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1
101
+
102
+ ZQ = 2
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = ks0
106
+ ZKV = 2
107
+ KV_LEN = ks1
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 2
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = ks2
148
+ stride_kv_idx_h = ks3*ks4
149
+ stride_kv_idx_m = ks4
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks5
245
+ stride_q_idx_h = ks6*ks7
246
+ stride_q_idx_n = ks6
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = ks0
385
+ KV_LEN = ks1
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = ks8
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = False
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = ks0
596
+ KV_LEN = ks1
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = False
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
682
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = ks8
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
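Stripped of the block-sparse bookkeeping, the per-tile dQ update in bwd_dq_block_mn above is the usual flash-attention backward recurrence. A loose eager-mode sketch of that math (shapes are illustrative, not the kernel's block sizes; lse is assumed to be the row logsumexp in log2 units, which is how the kernel consumes it):

    import torch

    def dq_block_ref(q, k, v, do, lse, delta, sm_scale):
        # q, do: (M, D); k, v: (N, D); lse, delta: (M,).
        # delta is the precomputed sum(OUT * DO, axis=-1), as noted in the kernel comments.
        RCP_LN2 = 1.44269504
        qk = (q @ k.T) * sm_scale                    # raw scores
        p = torch.exp2(qk * RCP_LN2 - lse[:, None])  # recomputed probabilities
        dp = do @ v.T                                # dL/dP
        ds = p * (dp - delta[:, None])               # dL/dS
        return (ds @ k) * sm_scale                   # dL/dQ, matching the final `dq *= SM_SCALE`

The dK/dV path in bwd_dkdv_block_mn is the transposed counterpart: dv accumulates pT @ do, dk accumulates (pT * (v @ do.T - delta)) @ q, and dk is likewise scaled by SM_SCALE just before the store.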
SpecForge-ext/cache/compiled_kernels/4h/9e00c3cbd4f3ffea506c2d972effa4f5d1a03b1819fbd2068ce6d04ad21a37d7.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 21, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"}
SpecForge-ext/cache/compiled_kernels/4h/c4hrpftpfto2n4yelfxmq5tawsfst2z5xq7othxvdoymqaudsvcw.py ADDED
@@ -0,0 +1,56 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 4194304},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x4 = xindex
23
+ x2 = ((xindex // ks0) % ks1)
24
+ x0 = (xindex % ks3)
25
+ x5 = xindex // ks3
26
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
27
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
28
+ tmp2 = ks2
29
+ tmp3 = tmp1 + tmp2
30
+ tmp4 = tmp1 < 0
31
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
32
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
33
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
34
+ tmp8 = tmp0 * tmp7
35
+ tmp9 = x0
36
+ tmp10 = tl.full([1], 0, tl.int64)
37
+ tmp11 = tmp9 >= tmp10
38
+ tmp12 = ks3 + (-1)*(ks3 // 2)
39
+ tmp13 = tmp9 < tmp12
40
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp15 = -tmp14
42
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
43
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
44
+ tmp18 = tmp9 >= tmp12
45
+ tmp19 = ks3
46
+ tmp20 = tmp9 < tmp19
47
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
48
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
49
+ tmp23 = ks4
50
+ tmp24 = tmp1 + tmp23
51
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
52
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
53
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
54
+ tmp28 = tmp22 * tmp27
55
+ tmp29 = tmp8 + tmp28
56
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
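The pointwise kernel above fuses what looks like a rotary-embedding update: it gathers a cos row and a sin row by a per-token position index (in_ptr1), builds rotate-half of the input (the negated second half concatenated with the first half), and stores x*cos + rotate_half(x)*sin. A rough eager-mode equivalent, assuming an even head dimension (tensor names and shapes are illustrative, not taken from the module):

    import torch

    def rope_ref(x, cos, sin, pos):
        # x: (batch, seq, dim); cos/sin: (max_pos, dim); pos: (seq,) gather indices.
        half = x.shape[-1] // 2
        rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)
        return x * cos[pos] + rotated * sin[pos]

    x = torch.randn(2, 3, 8, dtype=torch.bfloat16)
    cos = torch.randn(16, 8, dtype=torch.bfloat16)
    sin = torch.randn(16, 8, dtype=torch.bfloat16)
    out = rope_ref(x, cos, sin, torch.tensor([0, 1, 2]))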
SpecForge-ext/cache/compiled_kernels/4k/30a0e09dbdf44769796e9e261da2a9dcbfc798ae7811e19f9adc033f960f3fae.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"}
SpecForge-ext/cache/compiled_kernels/4k/c4kzcehfveyvvtlnmx5jh5naezqnmtz2ubxuawsucb27r43j5yfa.py ADDED
@@ -0,0 +1,49 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 256, 'r0_': 16},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 16
20
+ R0_BLOCK: tl.constexpr = 16
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_2 = r0_index
32
+ x0 = (xindex % ks0)
33
+ x1 = xindex // ks0
34
+ x3 = xindex
35
+ tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0)
36
+ tmp1 = r0_2
37
+ tmp2 = tmp1.to(tl.int16)
38
+ tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
39
+ tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
40
+ tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True)
41
+ tmp7 = tmp0.to(tl.int64)
42
+ tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
43
+ tmp10 = tl.where(xmask, tmp8, 0)
44
+ tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64)
45
+ tmp12 = tmp6.to(tl.int64)
46
+ tmp13 = tmp12.to(tl.int32)
47
+ tmp14 = tmp11.to(tl.int32)
48
+ tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask)
49
+ tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask)
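The persistent reduction above handles 16 values per row: it runs a stable descending sort that carries the original column positions (sort_with_index) and stores those positions, and separately stores each row's sum, both as int32. An eager-mode sketch of the same two outputs (the row contents below are made up):

    import torch

    def sort_idx_and_sum_ref(x):
        # x: (rows, 16) int32; return the column order after a stable descending
        # sort plus each row's total, mirroring the kernel's two stores.
        _, order = torch.sort(x, dim=1, descending=True, stable=True)
        return order.to(torch.int32), x.sum(dim=1).to(torch.int32)

    x = torch.tensor([[3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3]], dtype=torch.int32)
    order, totals = sort_idx_and_sum_ref(x)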
SpecForge-ext/cache/compiled_kernels/4m/c4mv34wib446qhr7sd5yhgc4mdneb7isnb6uitnbwvdgrbpgyf2s.py ADDED
@@ -0,0 +1,552 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* are defined in the block-sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1
88
+
89
+ ZQ = 2
90
+ HQ = 32
91
+ Q_LEN = ks0
92
+ ZKV = 2
93
+ KV_LEN = ks1
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 2
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = ks2
130
+ stride_kv_idx_h = ks3*ks4
131
+ stride_kv_idx_m = ks4
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to them
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831843
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB: reversed order since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = tl.full([1], False, tl.int1)
371
+ tmp2 = (m)
372
+ tmp3 = (n)
373
+ tmp4 = tmp2 >= tmp3
374
+ tmp5 = tmp3.to(tl.int64)
375
+ tmp6 = (off_z)
376
+ tmp7 = tl.load(in_ptr9 + tmp6)
377
+ tmp8 = tmp5 < tmp7
378
+ tmp9 = tmp2.to(tl.int64)
379
+ tmp10 = tmp9 < tmp7
380
+ tmp11 = tmp8 & tmp10
381
+ tmp12 = tmp4 & tmp11
382
+ tmp13 = tmp1 | tmp12
383
+ tmp14 = ks5
384
+ tmp15 = tmp3 >= tmp14
385
+ tmp16 = (tmp3 % tmp14)
386
+ tmp17 = tl.full([1], 0, tl.int32)
387
+ tmp18 = tmp16 != tmp17
388
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
389
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
390
+ tmp21 = tmp19 != tmp20
391
+ tmp22 = tmp18 & tmp21
392
+ tmp23 = tmp16 + tmp14
393
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
394
+ tmp25 = tmp24.to(tl.int64)
395
+ tmp26 = tmp25 < tmp7
396
+ tmp27 = tmp15 & tmp26
397
+ tmp28 = tmp3 - tmp2
398
+ tmp29 = (tmp28 % tmp14)
399
+ tmp30 = tmp29 != tmp17
400
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
401
+ tmp32 = tmp31 != tmp20
402
+ tmp33 = tmp30 & tmp32
403
+ tmp34 = tmp29 + tmp14
404
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
405
+ tmp36 = tmp35 == tmp17
406
+ tmp37 = tmp27 & tmp36
407
+ tmp38 = tmp13 | tmp37
408
+ mask_mod_output = tmp38
409
+
410
+
411
+ if CHECK_BLOCK_BOUNDARY:
412
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
413
+ # apply mask for partially unmasked blocks
414
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
415
+
416
+ if not PRESCALE_QK:
417
+ post_mod_scores *= RCP_LN2
418
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419
+
420
+ # -- compute scaling constant ---
421
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
422
+ if not ROWS_GUARANTEED_SAFE:
423
+ masked_out_rows = (m_ij == float("-inf"))
424
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
425
+ else:
426
+ m_ij_masked = m_ij
427
+
428
+ alpha = tl.math.exp2(m_i - m_ij_masked)
429
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
430
+
431
+ # NB: l_i update is pulled up here since it's a bit faster
432
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
433
+ # m_ij
434
+ l_i = l_i * alpha + tl.sum(p, 1)
435
+ # # -- scale and update acc --
436
+ acc = acc * alpha[:, None]
437
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
438
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
439
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
440
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
441
+
442
+ # -- update m_i
443
+ m_i = m_ij
444
+
445
+ return acc, l_i, m_i
446
+
447
+ @triton.jit
448
+ def forward_inner(
449
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
450
+ q, K, V,
451
+ desc_k, desc_v, Q_LEN, KV_LEN,
452
+ # accumulated values
453
+ acc, l_i, m_i,
454
+ # Offsets used as inputs to score_mod & mask_mod
455
+ # of size [BLOCK_M, BLOCK_N] or scalar.
456
+ off_z, off_h, offs_m, offs_n,
457
+ # Offsets needed for TMA loads
458
+ kv_start,
459
+ # blocksparse data
460
+ kv_indices, kv_num_blocks,
461
+ # start kv and end kv block
462
+ block_n_start, block_n_end,
463
+ MATMUL_PRECISION,
464
+ # Strides for K and V
465
+ stride_kk, stride_kn, stride_vn, stride_vk,
466
+ IS_FULL_BLOCKS,
467
+ ):
468
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
469
+ PRESCALE_QK : tl.constexpr = False
470
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
471
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
472
+ WRITE_DQ : tl.constexpr = True
473
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
474
+ OUTPUT_MAX : tl.constexpr = False
475
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
476
+ IS_DIVISIBLE : tl.constexpr = False
477
+ SM_SCALE : tl.constexpr = 0.08838834764831843
478
+ GQA_SHARED_HEADS : tl.constexpr = 4
479
+ HAS_FULL_BLOCKS : tl.constexpr = True
480
+ QK_HEAD_DIM : tl.constexpr = 128
481
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
482
+ V_HEAD_DIM : tl.constexpr = 128
483
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
484
+ SAFE_HEAD_DIM : tl.constexpr = True
485
+ USE_TMA : tl.constexpr = False
486
+ BLOCK_M : tl.constexpr = 128
487
+ BLOCK_N : tl.constexpr = 64
488
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
489
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
490
+ INDEX_DTYPE : tl.constexpr = tl.int32
491
+
492
+
493
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
494
+ RCP_LN2: tl.constexpr = 1.44269504
495
+
496
+ if PRESCALE_QK:
497
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
498
+
499
+ kv_offset = 0
500
+
501
+ # loop over k, v and update accumulator until block_n_end
502
+ for start_n in range(block_n_start, block_n_end):
503
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
504
+ if IS_DIVISIBLE:
505
+ acc, l_i, m_i = forward_block_mn(
506
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
507
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
508
+ # accumulated values
509
+ acc, l_i, m_i,
510
+ # Offsets
511
+ off_z, off_h, offs_m, offs_n,
512
+ # Offsets needed for TMA loads
513
+ kv_start,
514
+ kv_offset,
515
+ MATMUL_PRECISION, RCP_LN2,
516
+ # Strides for K and V
517
+ stride_kk, stride_kn, stride_vn, stride_vk,
518
+ IS_FULL_BLOCKS,
519
+ )
520
+ else:
521
+ # Benchmarks show that even if we apply mod & mask to each block for a non-divisible seqlen,
522
+ # it's on par with or slightly faster than applying them only to the last block in fwd.
523
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
524
+ # to the last block because it's much faster there.
525
+ acc, l_i, m_i = forward_block_mn(
526
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
527
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
528
+ # accumulated values
529
+ acc, l_i, m_i,
530
+ # Offsets
531
+ off_z, off_h, offs_m, offs_n,
532
+ # Offsets needed for TMA loads
533
+ kv_start,
534
+ kv_offset,
535
+ MATMUL_PRECISION, RCP_LN2,
536
+ # Strides for K and V
537
+ stride_kk, stride_kn, stride_vn, stride_vk,
538
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
539
+ )
540
+
541
+
542
+
543
+ offset = get_offset_for_next_block(
544
+ start_n, kv_indices, kv_num_blocks,
545
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
546
+ )
547
+
548
+ offs_n = offs_n + offset
549
+ kv_offset += offset
550
+
551
+
552
+ return acc, l_i, m_i
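
The kernel above is the Inductor flex-attention forward template: it walks the block-sparse KV blocks listed in KV_IDX/FULL_KV_IDX and maintains a streaming (online) softmax in base 2, rescaling the accumulator by exp2(m_old - m_new) whenever the running row maximum changes, and finally storing lse = m_i + log2(l_i). A minimal NumPy sketch of that update is given below; it mirrors forward_inner/forward_block_mn for the fully unmasked case only, and the name blockwise_attention and the dense per-head layout are illustrative assumptions, not part of the generated kernel.

import numpy as np

def blockwise_attention(q, k, v, block_n=64, scale=None):
    # q: [M, D], k: [N, D], v: [N, Dv] -- a dense stand-in for one (batch, head) slice.
    M, D = q.shape
    scale = 1.0 / np.sqrt(D) if scale is None else scale
    RCP_LN2 = 1.4426950408889634          # 1/ln(2): exp(x) == exp2(x * RCP_LN2)
    m_i = np.full(M, -np.inf)             # running row maximum (log2 domain)
    l_i = np.zeros(M)                     # running row sum of exp2(score - m_i)
    acc = np.zeros((M, v.shape[1]))       # unnormalized output accumulator
    for start in range(0, k.shape[0], block_n):
        kb, vb = k[start:start + block_n], v[start:start + block_n]
        s = (q @ kb.T) * scale * RCP_LN2  # scores, pre-multiplied so exp2 behaves like exp
        m_new = np.maximum(m_i, s.max(axis=1))
        alpha = np.exp2(m_i - m_new)      # rescales the previous partial sums
        p = np.exp2(s - m_new[:, None])
        l_i = l_i * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ vb
        m_i = m_new
    out = acc / l_i[:, None]
    lse = m_i + np.log2(l_i)              # logsumexp in log2 units, as the kernel stores it
    return out, lse

A quick check against a dense softmax (fully masked rows would additionally need the l_i == 0 -> 1 substitution the kernel applies):

q, k, v = np.random.randn(8, 16), np.random.randn(32, 16), np.random.randn(32, 16)
out, _ = blockwise_attention(q, k, v, block_n=8)
s = (q @ k.T) / np.sqrt(16)
p = np.exp(s - s.max(axis=1, keepdims=True)); p /= p.sum(axis=1, keepdims=True)
assert np.allclose(out, p @ v)
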
SpecForge-ext/cache/compiled_kernels/4u/c4uf4o6eypfpqr4isgii4opqr5i3brobwecljte7sqvztk2kyafz.py ADDED
@@ -0,0 +1,552 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* are defined in the block-sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1
88
+
89
+ ZQ = 2
90
+ HQ = 32
91
+ Q_LEN = 2048
92
+ ZKV = 2
93
+ KV_LEN = ks0
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 2
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = 16
130
+ stride_kv_idx_h = 16*ks1
131
+ stride_kv_idx_m = ks1
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to them
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831843
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB: reversed order since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = tl.full([1], False, tl.int1)
371
+ tmp2 = (m)
372
+ tmp3 = (n)
373
+ tmp4 = tmp2 >= tmp3
374
+ tmp5 = tmp3.to(tl.int64)
375
+ tmp6 = (off_z)
376
+ tmp7 = tl.load(in_ptr9 + tmp6)
377
+ tmp8 = tmp5 < tmp7
378
+ tmp9 = tmp2.to(tl.int64)
379
+ tmp10 = tmp9 < tmp7
380
+ tmp11 = tmp8 & tmp10
381
+ tmp12 = tmp4 & tmp11
382
+ tmp13 = tmp1 | tmp12
383
+ tmp14 = tl.full([1], 2048, tl.int32)
384
+ tmp15 = tmp3 >= tmp14
385
+ tmp16 = (tmp3 % tmp14)
386
+ tmp17 = tl.full([1], 0, tl.int32)
387
+ tmp18 = tmp16 != tmp17
388
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
389
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
390
+ tmp21 = tmp19 != tmp20
391
+ tmp22 = tmp18 & tmp21
392
+ tmp23 = tmp16 + tmp14
393
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
394
+ tmp25 = tmp24.to(tl.int64)
395
+ tmp26 = tmp25 < tmp7
396
+ tmp27 = tmp15 & tmp26
397
+ tmp28 = tmp3 - tmp2
398
+ tmp29 = (tmp28 % tmp14)
399
+ tmp30 = tmp29 != tmp17
400
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
401
+ tmp32 = tmp31 != tmp20
402
+ tmp33 = tmp30 & tmp32
403
+ tmp34 = tmp29 + tmp14
404
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
405
+ tmp36 = tmp35 == tmp17
406
+ tmp37 = tmp27 & tmp36
407
+ tmp38 = tmp13 | tmp37
408
+ mask_mod_output = tmp38
409
+
410
+
411
+ if CHECK_BLOCK_BOUNDARY:
412
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
413
+ # apply mask for partially unmasked blocks
414
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
415
+
416
+ if not PRESCALE_QK:
417
+ post_mod_scores *= RCP_LN2
418
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419
+
420
+ # -- compute scaling constant ---
421
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
422
+ if not ROWS_GUARANTEED_SAFE:
423
+ masked_out_rows = (m_ij == float("-inf"))
424
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
425
+ else:
426
+ m_ij_masked = m_ij
427
+
428
+ alpha = tl.math.exp2(m_i - m_ij_masked)
429
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
430
+
431
+ # NB: l_i update is pulled up here since it's a bit faster
432
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
433
+ # m_ij
434
+ l_i = l_i * alpha + tl.sum(p, 1)
435
+ # # -- scale and update acc --
436
+ acc = acc * alpha[:, None]
437
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
438
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
439
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
440
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
441
+
442
+ # -- update m_i
443
+ m_i = m_ij
444
+
445
+ return acc, l_i, m_i
446
+
447
+ @triton.jit
448
+ def forward_inner(
449
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
450
+ q, K, V,
451
+ desc_k, desc_v, Q_LEN, KV_LEN,
452
+ # accumulated values
453
+ acc, l_i, m_i,
454
+ # Offsets used as inputs to score_mod & mask_mod
455
+ # of size [BLOCK_M, BLOCK_N] or scalar.
456
+ off_z, off_h, offs_m, offs_n,
457
+ # Offsets needed for TMA loads
458
+ kv_start,
459
+ # blocksparse data
460
+ kv_indices, kv_num_blocks,
461
+ # start kv and end kv block
462
+ block_n_start, block_n_end,
463
+ MATMUL_PRECISION,
464
+ # Strides for K and V
465
+ stride_kk, stride_kn, stride_vn, stride_vk,
466
+ IS_FULL_BLOCKS,
467
+ ):
468
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
469
+ PRESCALE_QK : tl.constexpr = False
470
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
471
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
472
+ WRITE_DQ : tl.constexpr = True
473
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
474
+ OUTPUT_MAX : tl.constexpr = False
475
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
476
+ IS_DIVISIBLE : tl.constexpr = False
477
+ SM_SCALE : tl.constexpr = 0.08838834764831843
478
+ GQA_SHARED_HEADS : tl.constexpr = 4
479
+ HAS_FULL_BLOCKS : tl.constexpr = True
480
+ QK_HEAD_DIM : tl.constexpr = 128
481
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
482
+ V_HEAD_DIM : tl.constexpr = 128
483
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
484
+ SAFE_HEAD_DIM : tl.constexpr = True
485
+ USE_TMA : tl.constexpr = False
486
+ BLOCK_M : tl.constexpr = 128
487
+ BLOCK_N : tl.constexpr = 64
488
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
489
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
490
+ INDEX_DTYPE : tl.constexpr = tl.int32
491
+
492
+
493
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
494
+ RCP_LN2: tl.constexpr = 1.44269504
495
+
496
+ if PRESCALE_QK:
497
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
498
+
499
+ kv_offset = 0
500
+
501
+ # loop over k, v and update accumulator until block_n_end
502
+ for start_n in range(block_n_start, block_n_end):
503
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
504
+ if IS_DIVISIBLE:
505
+ acc, l_i, m_i = forward_block_mn(
506
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
507
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
508
+ # accumulated values
509
+ acc, l_i, m_i,
510
+ # Offsets
511
+ off_z, off_h, offs_m, offs_n,
512
+ # Offsets needed for TMA loads
513
+ kv_start,
514
+ kv_offset,
515
+ MATMUL_PRECISION, RCP_LN2,
516
+ # Strides for K and V
517
+ stride_kk, stride_kn, stride_vn, stride_vk,
518
+ IS_FULL_BLOCKS,
519
+ )
520
+ else:
521
+ # Benchmarks show that even if we apply mod & mask to each block for a non-divisible seqlen,
522
+ # it's on par with or slightly faster than applying them only to the last block in fwd.
523
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
524
+ # to the last block because it's much faster there.
525
+ acc, l_i, m_i = forward_block_mn(
526
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
527
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
528
+ # accumulated values
529
+ acc, l_i, m_i,
530
+ # Offsets
531
+ off_z, off_h, offs_m, offs_n,
532
+ # Offsets needed for TMA loads
533
+ kv_start,
534
+ kv_offset,
535
+ MATMUL_PRECISION, RCP_LN2,
536
+ # Strides for K and V
537
+ stride_kk, stride_kn, stride_vn, stride_vk,
538
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
539
+ )
540
+
541
+
542
+
543
+ offset = get_offset_for_next_block(
544
+ start_n, kv_indices, kv_num_blocks,
545
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
546
+ )
547
+
548
+ offs_n = offs_n + offset
549
+ kv_offset += offset
550
+
551
+
552
+ return acc, l_i, m_i
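
This second forward variant differs from the one above only in its specialized sizes and strides (here Q_LEN is fixed at 2048 and KV_LEN comes in as ks0). Two pieces of indexing shared by both variants are the GQA head mapping (off_hkv = off_hq // GQA_SHARED_HEADS) and the block-sparse KV traversal in get_offset_for_next_block, which steps BLOCK_N at a time inside a SPARSE_KV_BLOCK_SIZE block and jumps through KV_IDX once a sparse block is exhausted. Below is a small pure-Python sketch of that traversal under the constants these kernels use; kv_head_for_query_head and kv_tile_starts are illustrative helper names, not functions emitted by Inductor.

GQA_SHARED_HEADS = 4          # query heads per kv head (32 query heads -> 8 kv heads)
SPARSE_KV_BLOCK_SIZE = 128    # granularity of the block-sparse mask
BLOCK_N = 64                  # kv tile handled per inner-loop iteration
SPARSE_KV_MULTIPLE = SPARSE_KV_BLOCK_SIZE // BLOCK_N  # == 2

def kv_head_for_query_head(off_hq):
    # e.g. query heads 0..3 read kv head 0, heads 4..7 read kv head 1, ...
    return off_hq // GQA_SHARED_HEADS

def kv_tile_starts(kv_block_indices, kv_num_blocks):
    # Absolute kv offsets visited by forward_inner for one query block.
    # kv_block_indices plays the role of a KV_IDX row, kv_num_blocks of a KV_NUM_BLKS entry.
    offset = kv_block_indices[0] * SPARSE_KV_BLOCK_SIZE
    for loop_iter in range(kv_num_blocks * SPARSE_KV_MULTIPLE):
        yield offset
        cur_block_idx = loop_iter // SPARSE_KV_MULTIPLE
        if (loop_iter + 1) % SPARSE_KV_MULTIPLE == 0 and cur_block_idx + 1 < kv_num_blocks:
            jump = kv_block_indices[cur_block_idx + 1] - kv_block_indices[cur_block_idx]
            offset += jump * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N
        else:
            offset += BLOCK_N

For example, list(kv_tile_starts([0, 2], 2)) gives [0, 64, 256, 320]: two 64-wide tiles inside sparse block 0, then a jump over the skipped block to the two tiles of sparse block 2; kv_head_for_query_head(13) returns 3.
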
SpecForge-ext/cache/compiled_kernels/4u/c4uhrh7gjsy72in52pmmkpoiwetwjbked3nkrbcotbo4sj5bq7bi.py ADDED
@@ -0,0 +1,835 @@
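
The kernel below is the matching backward template (triton_flex_attention_backward): programs with pid below NUM_KV_BLOCKS accumulate dK/dV by iterating over Q blocks, the remaining programs accumulate dQ by iterating over KV blocks, and DELTA carries the precomputed rowwise sum(OUT * DO) used by the softmax backward. As a reference for those formulas only, here is a dense NumPy sketch under the standard attention definitions; it is an illustration, not the kernel's actual blockwise code path.

import numpy as np

def attention_backward_reference(q, k, v, do, scale):
    # Dense counterparts of the gradients the blockwise kernel accumulates.
    s = (q @ k.T) * scale
    p = np.exp(s - s.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)       # softmax probabilities
    out = p @ v
    delta = (out * do).sum(axis=1)          # DELTA: precomputed sum(OUT * DO, axis=-1)
    dv = p.T @ do
    dp = do @ v.T
    ds = p * (dp - delta[:, None])          # softmax backward, with delta as the rowwise correction
    dq = (ds @ k) * scale
    dk = (ds.T @ q) * scale
    return dq, dk, dv
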
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1
101
+
102
+ ZQ = 8
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 8
107
+ KV_LEN = ks0
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 8
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
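+ # RCP_LN2 is log2(e): since exp(x) == exp2(x * log2(e)), scaling scores by RCP_LN2 lets the
+ # kernels use the faster exp2 instruction below (the stored lse is assumed to be kept in
+ # matching base-2 units).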
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
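+ # Launch-grid sketch (inferred from the pid split above, not spelled out in this file):
+ # programs 0..NUM_KV_BLOCKS-1 compute dK/dV for one KV tile each, and the remaining
+ # GQA_SHARED_HEADS * NUM_Q_BLOCKS programs compute dQ, with off_pid decomposed into
+ # (query head within the shared group, query block). E.g. if KV_LEN (ks0) were 2048:
+ # NUM_KV_BLOCKS = 2048/128 = 16 and NUM_Q_BLOCKS = 2048/128 = 16, so grid dim 0 would be
+ # 16 + 4*16 = 80 per (batch, kv head).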
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 16*ks1
149
+ stride_kv_idx_m = ks1
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query head.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks2
245
+ stride_q_idx_h = 16*ks3
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query head.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
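+ # Layout note (a reading of the xindex arithmetic above, not asserted by the template):
+ # strides are d:1, n:128, h:128*ks0, z:1024*ks0, i.e. out_ptr0 appears to be a contiguous
+ # [ZQ=8, HKV=8, KV_LEN=ks0, 128] bf16 buffer holding dK per query batch, matching the
+ # broadcast-then-reduce scheme described in the comment above.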
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 2048
385
+ KV_LEN = ks0
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
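+ # Worked example of the clamp: SPARSE_KV_MULTIPLE = 128 // 64 = 2, so each sparse mask block
+ # spans two BLOCK_N2 tiles; if sparse_kv_num_blocks loaded as 3, the loop would visit
+ # min(3*2, cdiv(KV_LEN, 64)) tiles. The tl.maximum(..., 1) keeps the cap at least one tile,
+ # mirroring the SPARSE_BLOCK_SIZE note in bwd_dkdv_inner below.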
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done in the outer loop, but because we iterate across the N dim here,
466
+ # the M indices can read out of bounds for the program IDs spanning the Q_LEN boundary.
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = tl.full([1], 2048, tl.int32)
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = False
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = 2048
596
+ KV_LEN = ks0
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = False
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done in the outer loop, but because we iterate across the M dim here,
682
+ # the N indices can read out of bounds for the program IDs spanning the KV_LEN boundary.
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = tl.full([1], 2048, tl.int32)
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
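+ # Worked example (with the constants used above: SPARSE_BLOCK = 128, BLOCK = 64, so
+ # SPARSE_BLOCK_MULTIPLE = 2): on iterations that stay inside a sparse block the offset is
+ # simply BLOCK = 64; on the iteration that finishes a sparse block ((loop_iter + 1) % 2 == 0)
+ # and the mask jumps e.g. from sparse block 3 to sparse block 7, the offset becomes
+ # (7 - 3) * 128 - (2 - 1) * 64 = 448 elements.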
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
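+ # E.g. with Q_LEN = 2000 and a tile covering rows 1984..2047, rows >= 2000 wrap to 0..47 so
+ # downstream indirect indexing stays in bounds; the out-of-range rows are presumably
+ # neutralized by the explicit "< Q_LEN" masks applied around the score/mask mod.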
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
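+ # Usage sketch (mirroring the calls above, not an additional API): e.g.
+ #   q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd,
+ #                       IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
+ # loads a [BLOCK_M2, QK_HEAD_DIM] tile and masks only the dimensions not known to be
+ # divisible/safe; with IS_DIVISIBLE = False and SAFE_HEAD_DIM = True, only the seq-len
+ # (row) dimension is masked.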
SpecForge-ext/cache/compiled_kernels/4x/c4xykt7eysbenti5r55drq4w7k6c7fih4ifrou2alyqcn6r5enon.py ADDED
@@ -0,0 +1,835 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = True
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, which is written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: The number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The kernel options below can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1
101
+
102
+ ZQ = 2
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 2
107
+ KV_LEN = 2048
108
+
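+ # Note: in this specialization Q_LEN and KV_LEN are both 2048, an exact multiple of every
+ # block size used here (64 and 128), so IS_DIVISIBLE is True above and the masked branches
+ # of the checked loads/stores are constexpr-eliminated.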
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 2
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
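+ # With the static shapes of this kernel: NUM_KV_BLOCKS = 2048/128 = 16 and
+ # NUM_Q_BLOCKS = 2048/128 = 16, so (presumably) grid dim 0 is 16 + 4*16 = 80 programs per
+ # (batch, kv head): the first 16 handle dK/dV tiles, the remaining 64 handle dQ.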
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 256
149
+ stride_kv_idx_m = 16
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query head.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 16
245
+ stride_q_idx_h = 256
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query head.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
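+ # Layout note (a reading of the xindex arithmetic, not asserted by the template): strides are
+ # d:1, n:128, h:262144 (= 2048*128), z:2097152 (= 8*262144), i.e. out_ptr0 appears to be a
+ # contiguous [ZQ=2, HKV=8, KV_LEN=2048, 128] bf16 buffer holding the broadcast dK.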
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = True
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 2048
385
+ KV_LEN = 2048
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = True
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done in the outer loop, but because we iterate across the N dim here,
466
+ # the M indices can read out of bounds for the program IDs spanning the Q_LEN boundary.
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = tl.full([1], 2048, tl.int32)
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = True
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = 2048
596
+ KV_LEN = 2048
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = True
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done for the outer loop, but since we're iterating across the M dim here,
682
+ # the n indices can read out of bounds for the PIDs spanning the KV_LEN boundary
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = tl.full([1], 2048, tl.int32)
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
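+ # Same mask predicate as the tmp1-tmp38 chain in bwd_dq_block_mn above, here
+ # evaluated with m taken from offs_m1 (columns) and n from offs_n1 (rows) for the
+ # transposed dk/dv pass.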
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
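+ # Pointer-advance helper for the block-sparse inner loops: within a sparse block
+ # we step by BLOCK; once (loop_iter + 1) completes a sparse block we jump by the
+ # distance between consecutive col_indices entries, minus the BLOCKs already
+ # consumed inside the current sparse block.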
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
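The tmp* mask chains inlined throughout this file (and the forward kernel further below) all encode one predicate over per-batch sequence lengths. As a hedged reconstruction only, not the repository's actual source, the equivalent FlexAttention mask_mod might look like the sketch below; `seq_lens` and the hard-coded 2048 offset are read off the generated code, and the create_block_mask call assumes PyTorch's torch.nn.attention.flex_attention API.

    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    SUFFIX_OFFSET = 2048  # matches the constant baked into the generated kernels

    def make_mask_mod(seq_lens: torch.Tensor):
        # seq_lens: int64 tensor of shape [batch], analogous to in_ptr16 above
        def mask_mod(b, h, q_idx, kv_idx):
            causal = (q_idx >= kv_idx) & (kv_idx < seq_lens[b]) & (q_idx < seq_lens[b])
            wrapped_diag = (
                (kv_idx >= SUFFIX_OFFSET)
                & ((kv_idx % SUFFIX_OFFSET) < seq_lens[b])
                & (((kv_idx - q_idx) % SUFFIX_OFFSET) == 0)
            )
            return causal | wrapped_diag
        return mask_mod

    # e.g.: block_mask = create_block_mask(make_mask_mod(seq_lens), B, None, Q_LEN, KV_LEN)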
SpecForge-ext/cache/compiled_kernels/54/c5464ptly4n22voq77yo3wrltmxhbase2ojnypkgcpcxg6js4oty.py ADDED
@@ -0,0 +1,46 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 4096, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 262144000}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 4096
20
+ r0_numel = 32000
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x0 = xindex
29
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
30
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
31
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
32
+ r0_index = r0_offset + r0_base
33
+ r0_mask = r0_index < r0_numel
34
+ roffset = r0_offset
35
+ rindex = r0_index
36
+ r0_1 = r0_index
37
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
38
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
39
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
40
+ _tmp2, _tmp2_index, tmp1, rindex
41
+ )
42
+ _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
43
+ _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
44
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
45
+ tmp2 = tmp2_idx[:, None]
46
+ tl.store(out_ptr0 + (x0), tmp2, None)
SpecForge-ext/cache/compiled_kernels/54/c54p5bozrk7z3jkhpl6meytxfu7bz7ojmkijrdgczbq55oalwpgl.py ADDED
@@ -0,0 +1,552 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but is slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1
88
+
89
+ ZQ = 8
90
+ HQ = 32
91
+ Q_LEN = 2048
92
+ ZKV = 8
93
+ KV_LEN = ks0
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 8
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = 16
130
+ stride_kv_idx_h = 16*ks1
131
+ stride_kv_idx_m = ks1
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to them
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831843
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = tl.full([1], False, tl.int1)
371
+ tmp2 = (m)
372
+ tmp3 = (n)
373
+ tmp4 = tmp2 >= tmp3
374
+ tmp5 = tmp3.to(tl.int64)
375
+ tmp6 = (off_z)
376
+ tmp7 = tl.load(in_ptr9 + tmp6)
377
+ tmp8 = tmp5 < tmp7
378
+ tmp9 = tmp2.to(tl.int64)
379
+ tmp10 = tmp9 < tmp7
380
+ tmp11 = tmp8 & tmp10
381
+ tmp12 = tmp4 & tmp11
382
+ tmp13 = tmp1 | tmp12
383
+ tmp14 = tl.full([1], 2048, tl.int32)
384
+ tmp15 = tmp3 >= tmp14
385
+ tmp16 = (tmp3 % tmp14)
386
+ tmp17 = tl.full([1], 0, tl.int32)
387
+ tmp18 = tmp16 != tmp17
388
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
389
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
390
+ tmp21 = tmp19 != tmp20
391
+ tmp22 = tmp18 & tmp21
392
+ tmp23 = tmp16 + tmp14
393
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
394
+ tmp25 = tmp24.to(tl.int64)
395
+ tmp26 = tmp25 < tmp7
396
+ tmp27 = tmp15 & tmp26
397
+ tmp28 = tmp3 - tmp2
398
+ tmp29 = (tmp28 % tmp14)
399
+ tmp30 = tmp29 != tmp17
400
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
401
+ tmp32 = tmp31 != tmp20
402
+ tmp33 = tmp30 & tmp32
403
+ tmp34 = tmp29 + tmp14
404
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
405
+ tmp36 = tmp35 == tmp17
406
+ tmp37 = tmp27 & tmp36
407
+ tmp38 = tmp13 | tmp37
408
+ mask_mod_output = tmp38
409
+
410
+
411
+ if CHECK_BLOCK_BOUNDARY:
412
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
413
+ # apply mask for partially unmasked blocks
414
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
415
+
416
+ if not PRESCALE_QK:
417
+ post_mod_scores *= RCP_LN2
418
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419
+
420
+ # -- compute scaling constant ---
421
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
422
+ if not ROWS_GUARANTEED_SAFE:
423
+ masked_out_rows = (m_ij == float("-inf"))
424
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
425
+ else:
426
+ m_ij_masked = m_ij
427
+
428
+ alpha = tl.math.exp2(m_i - m_ij_masked)
429
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
430
+
431
+ # NB: l_i update is pulled up here since it's a bit faster
432
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
433
+ # m_ij
434
+ l_i = l_i * alpha + tl.sum(p, 1)
435
+ # # -- scale and update acc --
436
+ acc = acc * alpha[:, None]
437
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
438
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
439
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
440
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
441
+
442
+ # -- update m_i
443
+ m_i = m_ij
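+ # Streaming-softmax invariant maintained across KV blocks (base-2 exponentials):
+ #   m_i = running row-max of the scaled scores
+ #   l_i = sum_j exp2(score_j - m_i)
+ #   acc = sum_j exp2(score_j - m_i) * v_j
+ # alpha rescales the previous partial sums whenever the running max changes.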
444
+
445
+ return acc, l_i, m_i
446
+
447
+ @triton.jit
448
+ def forward_inner(
449
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
450
+ q, K, V,
451
+ desc_k, desc_v, Q_LEN, KV_LEN,
452
+ # accumulated values
453
+ acc, l_i, m_i,
454
+ # Offsets used as inputs to score_mod & mask_mod
455
+ # of size [BLOCK_M, BLOCK_N] or scalar.
456
+ off_z, off_h, offs_m, offs_n,
457
+ # Offsets needed for TMA loads
458
+ kv_start,
459
+ # blocksparse data
460
+ kv_indices, kv_num_blocks,
461
+ # start kv and end kv block
462
+ block_n_start, block_n_end,
463
+ MATMUL_PRECISION,
464
+ # Strides for K and V
465
+ stride_kk, stride_kn, stride_vn, stride_vk,
466
+ IS_FULL_BLOCKS,
467
+ ):
468
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
469
+ PRESCALE_QK : tl.constexpr = False
470
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
471
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
472
+ WRITE_DQ : tl.constexpr = True
473
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
474
+ OUTPUT_MAX : tl.constexpr = False
475
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
476
+ IS_DIVISIBLE : tl.constexpr = False
477
+ SM_SCALE : tl.constexpr = 0.08838834764831843
478
+ GQA_SHARED_HEADS : tl.constexpr = 4
479
+ HAS_FULL_BLOCKS : tl.constexpr = True
480
+ QK_HEAD_DIM : tl.constexpr = 128
481
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
482
+ V_HEAD_DIM : tl.constexpr = 128
483
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
484
+ SAFE_HEAD_DIM : tl.constexpr = True
485
+ USE_TMA : tl.constexpr = False
486
+ BLOCK_M : tl.constexpr = 128
487
+ BLOCK_N : tl.constexpr = 64
488
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
489
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
490
+ INDEX_DTYPE : tl.constexpr = tl.int32
491
+
492
+
493
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
494
+ RCP_LN2: tl.constexpr = 1.44269504
495
+
496
+ if PRESCALE_QK:
497
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
498
+
499
+ kv_offset = 0
500
+
501
+ # loop over k, v and update accumulator until block_n_end
502
+ for start_n in range(block_n_start, block_n_end):
503
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
504
+ if IS_DIVISIBLE:
505
+ acc, l_i, m_i = forward_block_mn(
506
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
507
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
508
+ # accumulated values
509
+ acc, l_i, m_i,
510
+ # Offsets
511
+ off_z, off_h, offs_m, offs_n,
512
+ # Offsets needed for TMA loads
513
+ kv_start,
514
+ kv_offset,
515
+ MATMUL_PRECISION, RCP_LN2,
516
+ # Strides for K and V
517
+ stride_kk, stride_kn, stride_vn, stride_vk,
518
+ IS_FULL_BLOCKS,
519
+ )
520
+ else:
521
+ # Benchmarks show that even when we apply mod & mask to every block for non divisible seqlens,
522
+ # it's on par with or slightly faster than only applying them to the last block in fwd.
523
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
524
+ # to the last block because it's significantly faster.
525
+ acc, l_i, m_i = forward_block_mn(
526
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
527
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
528
+ # accumulated values
529
+ acc, l_i, m_i,
530
+ # Offsets
531
+ off_z, off_h, offs_m, offs_n,
532
+ # Offsets needed for TMA loads
533
+ kv_start,
534
+ kv_offset,
535
+ MATMUL_PRECISION, RCP_LN2,
536
+ # Strides for K and V
537
+ stride_kk, stride_kn, stride_vn, stride_vk,
538
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
539
+ )
540
+
541
+
542
+
543
+ offset = get_offset_for_next_block(
544
+ start_n, kv_indices, kv_num_blocks,
545
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
546
+ )
547
+
548
+ offs_n = offs_n + offset
549
+ kv_offset += offset
550
+
551
+
552
+ return acc, l_i, m_i
SpecForge-ext/cache/compiled_kernels/5h/c5h6tol66uk77tfumu3xd25ecbr6kkxkqgk3zbmjpk4tc6sikmjb.py ADDED
@@ -0,0 +1,543 @@
1
+ # AOT ID: ['5_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/jh/cjhd7kndnunfa7ikwg3gxzzxuods7fnn5vlwqbhjxnla3dldi6sq.py
38
+ # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum]
39
+ # Source node to ATen node mapping:
40
+ # and_2 => bitwise_and_1
41
+ # and_3 => bitwise_and_2
42
+ # and_4 => bitwise_and_3, view_8
43
+ # b => iota
44
+ # batched_outputs_2 => view_9
45
+ # causal_mask => ge, view
46
+ # diagnol_mask => eq
47
+ # index => index
48
+ # index_1 => index_1
49
+ # index_2 => index_2
50
+ # lt => lt, view_1
51
+ # lt_1 => lt_1, view_2
52
+ # m => iota_2
53
+ # mask_2 => view_10
54
+ # mask_3 => permute
55
+ # mask_block_sum => sum_1
56
+ # n => iota_3
57
+ # padding_mask => bitwise_and, view_3, view_4
58
+ # padding_mask_1 => lt_2, view_6
59
+ # remainder => remainder
60
+ # remainder_1 => remainder_1
61
+ # result_1 => bitwise_or, full_default
62
+ # result_2 => bitwise_or_1
63
+ # sub => sub, view_7
64
+ # suffix_mask => ge_1
65
+ # Graph fragment:
66
+ # %arg0_1 : Tensor "i64[8][1]cuda:1" = PlaceHolder[target=arg0_1]
67
+ # %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:1, pin_memory: False})
68
+ # %iota_2 : Tensor "i64[2048][1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False})
69
+ # %view : Tensor "i64[2048, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {})
70
+ # %iota_3 : Tensor "i64[2048][1]cuda:1"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False})
71
+ # %ge : Tensor "b8[2048, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {})
72
+ # %iota : Tensor "i64[8][1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False})
73
+ # %index : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {})
74
+ # %view_1 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {})
75
+ # %lt : Tensor "b8[8, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {})
76
+ # %view_4 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, 2048]), kwargs = {})
77
+ # %index_1 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {})
78
+ # %view_2 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {})
79
+ # %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {})
80
+ # %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {})
81
+ # %bitwise_and : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {})
82
+ # %bitwise_and_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {})
83
+ # %bitwise_or : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {})
84
+ # %ge_1 : Tensor "b8[2048][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {})
85
+ # %remainder : Tensor "i64[2048][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {})
86
+ # %index_2 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {})
87
+ # %view_6 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {})
88
+ # %lt_2 : Tensor "b8[8, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {})
89
+ # %bitwise_and_2 : Tensor "b8[8, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_2), kwargs = {})
90
+ # %view_8 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, 2048]), kwargs = {})
91
+ # %view_7 : Tensor "i64[2048, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {})
92
+ # %sub : Tensor "i64[2048, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {})
93
+ # %remainder_1 : Tensor "i64[2048, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub, 2048), kwargs = {})
94
+ # %eq : Tensor "b8[2048, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {})
95
+ # %bitwise_and_3 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq), kwargs = {})
96
+ # %bitwise_or_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {})
97
+ # %view_9 : Tensor "b8[8, 1, 2048, 2048][4194304, 4194304, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, 2048]), kwargs = {})
98
+ # %view_10 : Tensor "b8[8, 1, 16, 128, 16, 128][4194304, 4194304, 262144, 2048, 128, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [8, 1, 16, 128, 16, 128]), kwargs = {})
99
+ # %permute : Tensor "b8[8, 1, 16, 16, 128, 128][4194304, 4194304, 262144, 128, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {})
100
+ # %sum_1 : Tensor "i64[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {})
101
+ # return %sum_1
102
+ triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', '''
103
+ import triton
104
+ import triton.language as tl
105
+
106
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
107
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
108
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
109
+ triton_helpers.set_driver_to_gpu()
110
+
111
+ @triton_heuristics.reduction(
112
+ size_hints={'x': 2048, 'r0_': 16384},
113
+ reduction_hint=ReductionHint.INNER,
114
+ filename=__file__,
115
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
116
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}}
117
+ )
118
+ @triton.jit
119
+ def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
120
+ xnumel = 2048
121
+ r0_numel = 16384
122
+ rnumel = r0_numel
123
+ RBLOCK: tl.constexpr = R0_BLOCK
124
+ xoffset = tl.program_id(0) * XBLOCK
125
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
126
+ xmask = xindex < xnumel
127
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
128
+ rbase = r0_base
129
+ x1 = ((xindex // 16) % 16)
130
+ x0 = (xindex % 16)
131
+ x2 = xindex // 256
132
+ tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
133
+ _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
134
+ x6 = xindex
135
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
136
+ r0_index = r0_offset + r0_base
137
+ r0_mask = r0_index < r0_numel
138
+ roffset = r0_offset
139
+ rindex = r0_index
140
+ r0_4 = r0_index // 128
141
+ r0_3 = (r0_index % 128)
142
+ tmp0 = r0_4 + 128*x1
143
+ tmp1 = r0_3 + 128*x0
144
+ tmp2 = tmp0 >= tmp1
145
+ tmp4 = tmp1 < tmp3
146
+ tmp5 = tmp0 < tmp3
147
+ tmp6 = tmp4 & tmp5
148
+ tmp7 = tmp2 & tmp6
149
+ tmp8 = tl.full([1, 1], False, tl.int1)
150
+ tmp9 = tmp8 | tmp7
151
+ tmp10 = tl.full([1, 1], 2048, tl.int64)
152
+ tmp11 = tmp1 >= tmp10
153
+ tmp12 = tmp11 & tmp4
154
+ tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
155
+ tmp14 = (tmp13 % tmp10)
156
+ tmp15 = tl.full([1, 1], 0, tl.int32)
157
+ tmp16 = tmp14 != tmp15
158
+ tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
159
+ tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0
160
+ tmp19 = tmp17 != tmp18
161
+ tmp20 = tmp16 & tmp19
162
+ tmp21 = tmp14 + tmp10
163
+ tmp22 = tl.where(tmp20, tmp21, tmp14)
164
+ tmp23 = tl.full([1, 1], 0, tl.int64)
165
+ tmp24 = tmp22 == tmp23
166
+ tmp25 = tmp12 & tmp24
167
+ tmp26 = tmp9 | tmp25
168
+ tmp27 = tmp26.to(tl.int64)
169
+ tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK])
170
+ tmp30 = _tmp29 + tmp28
171
+ _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29)
172
+ tmp29 = tl.sum(_tmp29, 1)[:, None]
173
+ tl.store(out_ptr0 + (x6), tmp29, xmask)
174
+ ''', device_str='cuda')
175
+
176
+
177
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/6f/c6fuhct5vdp3d5lx45chz27ghag5dfreh2h3hbzxl5elhim3qhpx.py
178
+ # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros]
179
+ # Source node to ATen node mapping:
180
+ # dense_mask_4 => full_default_4
181
+ # Graph fragment:
182
+ # %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False})
183
+ # return %index_put_1
184
+ triton_poi_fused_new_zeros_1 = async_compile.triton('triton_poi_fused_new_zeros_1', '''
185
+ import triton
186
+ import triton.language as tl
187
+
188
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
189
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
190
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
191
+ triton_helpers.set_driver_to_gpu()
192
+
193
+ @triton_heuristics.pointwise(
194
+ size_hints={'x': 4096},
195
+ filename=__file__,
196
+ triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
197
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}},
198
+ min_elem_per_thread=0
199
+ )
200
+ @triton.jit
201
+ def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr):
202
+ xnumel = 2176
203
+ xoffset = tl.program_id(0) * XBLOCK
204
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
205
+ xmask = xindex < xnumel
206
+ x0 = xindex
207
+ tmp0 = tl.full([1], 0, tl.int32)
208
+ tl.store(out_ptr0 + (x0), tmp0, xmask)
209
+ ''', device_str='cuda')
210
+
211
+
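# Illustrative note, not part of the generated file: triton_poi_fused_new_zeros_1
# above only zero-fills a 2176-element int32 buffer, i.e. roughly
#
#     dense_mask = torch.zeros((8, 1, 16, 17), dtype=torch.int32, device="cuda")
#
# The trailing 17th column appears to serve as a dump slot for "invalid"
# indices in the index_put performed by the next kernel.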
212
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/c2/cc2qlkbbemfommyywsdbow3sqg7jqf5x5tfkbqjzo2qy6lt36yjr.py
213
+ # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
214
+ # Source node to ATen node mapping:
215
+ # arange_4 => iota_4
216
+ # arange_6 => iota_8
217
+ # child_3 => convert_element_type_3
218
+ # child_4 => convert_element_type_4
219
+ # child_7 => convert_element_type_6
220
+ # child_8 => convert_element_type_7
221
+ # col_indices => sort
222
+ # col_indices_1 => sort_1
223
+ # col_range => iota_5
224
+ # col_range_1 => iota_9
225
+ # dense_mask => convert_element_type_2
226
+ # dense_mask_1 => convert_element_type_5
227
+ # dense_mask_2 => full_default_1
228
+ # dense_mask_4 => full_default_4
229
+ # full_blocks => eq_1
230
+ # full_blocks_1 => convert_element_type_1
231
+ # gt => gt
232
+ # index_mask => lt_4
233
+ # index_mask_1 => lt_5
234
+ # lt_3 => lt_3
235
+ # num_blocks_in_row => sum_2
236
+ # num_blocks_in_row_1 => sum_3
237
+ # partial_blocks => bitwise_and_4
238
+ # partial_blocks_1 => convert_element_type
239
+ # row_indices => unsqueeze
240
+ # row_indices_1 => unsqueeze_7
241
+ # setitem => full_default_3, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6
242
+ # setitem_1 => full_default_6, index_put_1, iota_10, iota_11, unsqueeze_10, unsqueeze_11, unsqueeze_12, unsqueeze_13, unsqueeze_9
243
+ # unsqueeze_1 => unsqueeze_1
244
+ # unsqueeze_3 => unsqueeze_8
245
+ # valid_indices => full_default_2, where
246
+ # valid_indices_1 => full_default_5, where_1
247
+ # Graph fragment:
248
+ # %sum_1 : Tensor "i64[8, 1, 16, 16][256, 2048, 16, 1]cuda:1" = PlaceHolder[target=sum_1]
249
+ # %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:1" = PlaceHolder[target=sum_2]
250
+ # %sum_3 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:1" = PlaceHolder[target=sum_3]
251
+ # %buf2 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:1" = PlaceHolder[target=buf2]
252
+ # %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=convert_element_type_3]
253
+ # %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1" = PlaceHolder[target=convert_element_type_4]
254
+ # %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1" = PlaceHolder[target=index_put]
255
+ # %buf4 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:1" = PlaceHolder[target=buf4]
256
+ # %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=convert_element_type_6]
257
+ # %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1" = PlaceHolder[target=convert_element_type_7]
258
+ # %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1" = PlaceHolder[target=index_put_1]
259
+ # %gt : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
260
+ # %lt_3 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {})
261
+ # %bitwise_and_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {})
262
+ # %convert_element_type : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {})
263
+ # %convert_element_type_2 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {})
264
+ # %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_2,), kwargs = {stable: True, descending: True})
265
+ # %eq_1 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {})
266
+ # %convert_element_type_1 : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_1, torch.int8), kwargs = {})
267
+ # %convert_element_type_5 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {})
268
+ # %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_5,), kwargs = {stable: True, descending: True})
269
+ # %full_default_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False})
270
+ # %iota_7 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False})
271
+ # %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {})
272
+ # %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {})
273
+ # %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {})
274
+ # %iota_6 : Tensor "i64[1][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False})
275
+ # %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {})
276
+ # %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {})
277
+ # %iota_4 : Tensor "i32[16][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:1, requires_grad: False})
278
+ # %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {})
279
+ # %iota_5 : Tensor "i32[16][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:1, requires_grad: False})
280
+ # %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {})
281
+ # %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {})
282
+ # %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {})
283
+ # %lt_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {})
284
+ # %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {})
285
+ # %full_default_2 : Tensor "i32[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False})
286
+ # %where : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %full_default_2), kwargs = {})
287
+ # %full_default_3 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False})
288
+ # %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_3), kwargs = {})
289
+ # %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False})
290
+ # %iota_11 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False})
291
+ # %unsqueeze_11 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_11, -1), kwargs = {})
292
+ # %unsqueeze_12 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_11, -1), kwargs = {})
293
+ # %unsqueeze_13 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_12, -1), kwargs = {})
294
+ # %iota_10 : Tensor "i64[1][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False})
295
+ # %unsqueeze_9 : Tensor "i64[1, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_10, -1), kwargs = {})
296
+ # %unsqueeze_10 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_9, -1), kwargs = {})
297
+ # %iota_8 : Tensor "i32[16][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:1, requires_grad: False})
298
+ # %unsqueeze_7 : Tensor "i32[16, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_8, -1), kwargs = {})
299
+ # %iota_9 : Tensor "i32[16][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:1, requires_grad: False})
300
+ # %sum_3 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_5, [-1]), kwargs = {})
301
+ # %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, torch.int32), kwargs = {})
302
+ # %unsqueeze_8 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_6, 3), kwargs = {})
303
+ # %lt_5 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_9, %unsqueeze_8), kwargs = {})
304
+ # %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_3, torch.int32), kwargs = {})
305
+ # %full_default_5 : Tensor "i32[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False})
306
+ # %where_1 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_5, %convert_element_type_7, %full_default_5), kwargs = {})
307
+ # %full_default_6 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False})
308
+ # %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_4, [%unsqueeze_13, %unsqueeze_10, %unsqueeze_7, %where_1], %full_default_6), kwargs = {})
309
+ # return %buf2,%buf4,%sum_2,%sum_3,%convert_element_type_3,%convert_element_type_6,%convert_element_type_4,%buf9,%convert_element_type_7,%buf16
310
+ triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 = async_compile.triton('triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', '''
311
+ import triton
312
+ import triton.language as tl
313
+
314
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
315
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
316
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
317
+ triton_helpers.set_driver_to_gpu()
318
+
319
+ @triton_heuristics.persistent_reduction(
320
+ size_hints={'x': 128, 'r0_': 16},
321
+ reduction_hint=ReductionHint.DEFAULT,
322
+ filename=__file__,
323
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
324
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
325
+ )
326
+ @triton.jit
327
+ def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr):
328
+ xnumel = 128
329
+ r0_numel = 16
330
+ R0_BLOCK: tl.constexpr = 16
331
+ rnumel = r0_numel
332
+ RBLOCK: tl.constexpr = R0_BLOCK
333
+ xoffset = tl.program_id(0) * XBLOCK
334
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
335
+ xmask = xindex < xnumel
336
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
337
+ r0_offset = 0
338
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
339
+ roffset = r0_offset
340
+ rindex = r0_index
341
+ r0_1 = r0_index
342
+ x0 = xindex
343
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0)
344
+ tmp1 = tl.full([1, 1], 0, tl.int64)
345
+ tmp2 = tmp0 > tmp1
346
+ tmp3 = tl.full([1, 1], 16384, tl.int64)
347
+ tmp4 = tmp0 < tmp3
348
+ tmp5 = tmp2 & tmp4
349
+ tmp6 = tmp5.to(tl.int8)
350
+ tmp7 = tmp6.to(tl.int32)
351
+ tmp8 = r0_1
352
+ tmp9 = tmp8.to(tl.int16)
353
+ tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
354
+ tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
355
+ tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True)
356
+ tmp14 = tmp0 == tmp3
357
+ tmp15 = tmp14.to(tl.int8)
358
+ tmp16 = tmp15.to(tl.int32)
359
+ tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
360
+ tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True)
361
+ tmp20 = tmp7.to(tl.int64)
362
+ tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK])
363
+ tmp23 = tl.where(xmask, tmp21, 0)
364
+ tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64)
365
+ tmp25 = tmp16.to(tl.int64)
366
+ tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
367
+ tmp28 = tl.where(xmask, tmp26, 0)
368
+ tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64)
369
+ tmp30 = tmp24.to(tl.int32)
370
+ tmp31 = tmp29.to(tl.int32)
371
+ tmp32 = tmp13.to(tl.int64)
372
+ tmp33 = tmp32.to(tl.int32)
373
+ tmp34 = tmp8 < tmp30
374
+ tmp35 = tl.full([1, 1], 16, tl.int32)
375
+ tmp36 = tl.where(tmp34, tmp33, tmp35)
376
+ tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32)
377
+ tmp38 = tmp36 + tmp37
378
+ tmp39 = tmp36 < 0
379
+ tmp40 = tl.where(tmp39, tmp38, tmp36)
380
+ tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17")
381
+ tmp42 = tl.full([1, 1], 1, tl.int32)
382
+ tmp43 = tmp19.to(tl.int64)
383
+ tmp44 = tmp43.to(tl.int32)
384
+ tmp45 = tmp8 < tmp31
385
+ tmp46 = tl.where(tmp45, tmp44, tmp35)
386
+ tmp47 = tmp46 + tmp37
387
+ tmp48 = tmp46 < 0
388
+ tmp49 = tl.where(tmp48, tmp47, tmp46)
389
+ tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17")
390
+ tl.store(out_ptr4 + (x0), tmp30, xmask)
391
+ tl.store(out_ptr5 + (x0), tmp31, xmask)
392
+ tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask)
393
+ tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
394
+ tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask)
395
+ tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
396
+ ''', device_str='cuda')
397
+
398
+
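# Illustrative sketch, not part of the generated file: the fused kernel above
# appears to turn a dense per-block mask into the (num_blocks, block_indices)
# layout flex attention consumes.  Partial blocks (0 < hits < 128*128) and full
# blocks (hits == 128*128) are handled separately; each row is stable-sorted in
# descending order so active block ids come first, and per-row counts are
# written out.  A hedged eager-mode equivalent (names are assumptions):
import torch

def dense_to_blocksparse(dense: torch.Tensor):
    """dense: (B, H, Q_BLOCKS, KV_BLOCKS) 0/1 tensor -> (counts, sorted indices)."""
    num_blocks = dense.sum(dim=-1).to(torch.int32)  # active blocks per row
    # stable descending sort keeps the original order among the kept blocks
    _, indices = dense.to(torch.int32).sort(dim=-1, descending=True, stable=True)
    return num_blocks, indices.to(torch.int32)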
399
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/nj/cnjktwj7h4iwx4zghbum5atne46yt4ce4t5jnkkvyag35pn7glnh.py
400
+ # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
401
+ # Source node to ATen node mapping:
402
+ # batched_outputs_3 => clone_4, slice_2
403
+ # col_indices_2 => sort_2
404
+ # num_blocks_in_row_2 => sum_4
405
+ # q_indices => clone_6, convert_element_type_9
406
+ # q_num_blocks => convert_element_type_8
407
+ # transpose => permute_1
408
+ # Graph fragment:
409
+ # %buf9 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:1" = PlaceHolder[target=buf9]
410
+ # %buf11 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:1" = PlaceHolder[target=buf11]
411
+ # %sum_4 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:1" = PlaceHolder[target=sum_4]
412
+ # %slice_2 : Tensor "i32[8, 1, 16, 16][272, 272, 17, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, 16), kwargs = {})
413
+ # %clone_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format})
414
+ # %permute_1 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {})
415
+ # %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True})
416
+ # %convert_element_type_9 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {})
417
+ # %clone_6 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format})
418
+ # %sum_4 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {})
419
+ # %convert_element_type_8 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {})
420
+ # return %buf11,%sum_4,%clone_6,%convert_element_type_8
421
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', '''
422
+ import triton
423
+ import triton.language as tl
424
+
425
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
426
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
427
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
428
+ triton_helpers.set_driver_to_gpu()
429
+
430
+ @triton_heuristics.persistent_reduction(
431
+ size_hints={'x': 128, 'r0_': 16},
432
+ reduction_hint=ReductionHint.DEFAULT,
433
+ filename=__file__,
434
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
435
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}}
436
+ )
437
+ @triton.jit
438
+ def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr):
439
+ xnumel = 128
440
+ r0_numel = 16
441
+ R0_BLOCK: tl.constexpr = 16
442
+ rnumel = r0_numel
443
+ RBLOCK: tl.constexpr = R0_BLOCK
444
+ xoffset = tl.program_id(0) * XBLOCK
445
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
446
+ xmask = xindex < xnumel
447
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
448
+ r0_offset = 0
449
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
450
+ roffset = r0_offset
451
+ rindex = r0_index
452
+ r0_2 = r0_index
453
+ x0 = (xindex % 16)
454
+ x1 = xindex // 16
455
+ x3 = xindex
456
+ tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0)
457
+ tmp1 = r0_2
458
+ tmp2 = tmp1.to(tl.int16)
459
+ tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
460
+ tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
461
+ tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True)
462
+ tmp7 = tmp0.to(tl.int64)
463
+ tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
464
+ tmp10 = tl.where(xmask, tmp8, 0)
465
+ tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64)
466
+ tmp12 = tmp6.to(tl.int64)
467
+ tmp13 = tmp12.to(tl.int32)
468
+ tmp14 = tmp11.to(tl.int32)
469
+ tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask)
470
+ tl.store(out_ptr3 + (x3), tmp14, xmask)
471
+ ''', device_str='cuda')
472
+
473
+
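# Illustrative note, not part of the generated file: the kernel above reads the
# (..., 16, 17) scatter buffer, drops the dump column, transposes the last two
# dims and repeats the sort/count step, which appears to produce the q-major
# metadata (q_num_blocks / q_indices) needed by the backward pass.  Roughly, in
# eager terms (using the dense_to_blocksparse sketch above; names assumed):
#
#     kv_major = scatter_buf[..., :16]                  # drop column 16
#     q_major = kv_major.transpose(-1, -2).contiguous()
#     q_num_blocks, q_indices = dense_to_blocksparse(q_major)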
474
+ async_compile.wait(globals())
475
+ del async_compile
476
+
477
+ class Runner:
478
+ def __init__(self, partitions):
479
+ self.partitions = partitions
480
+
481
+ def recursively_apply_fns(self, fns):
482
+ new_callables = []
483
+ for fn, c in zip(fns, self.partitions):
484
+ new_callables.append(fn(c))
485
+ self.partitions = new_callables
486
+
487
+ def call(self, args):
488
+ arg0_1, = args
489
+ args.clear()
490
+ assert_size_stride(arg0_1, (8, ), (1, ))
491
+ with torch.cuda._DeviceGuard(1):
492
+ torch.cuda.set_device(1)
493
+ buf0 = empty_strided_cuda((8, 1, 16, 16), (256, 2048, 16, 1), torch.int64)
494
+ # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum]
495
+ stream1 = get_raw_stream(1)
496
+ triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0.run(arg0_1, buf0, 2048, 16384, stream=stream1)
497
+ del arg0_1
498
+ buf15 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32)
499
+ # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros]
500
+ stream1 = get_raw_stream(1)
501
+ triton_poi_fused_new_zeros_1.run(buf15, 2176, stream=stream1)
502
+ buf8 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32)
503
+ # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros]
504
+ stream1 = get_raw_stream(1)
505
+ triton_poi_fused_new_zeros_1.run(buf8, 2176, stream=stream1)
506
+ buf6 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
507
+ buf13 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
508
+ buf7 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32)
509
+ buf14 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32)
510
+ # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
511
+ stream1 = get_raw_stream(1)
512
+ triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.run(buf0, buf6, buf13, buf7, buf8, buf14, buf15, 128, 16, stream=stream1)
513
+ del buf0
514
+ buf22 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32)
515
+ buf24 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
516
+ # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
517
+ stream1 = get_raw_stream(1)
518
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf8, buf22, buf24, 128, 16, stream=stream1)
519
+ del buf8
520
+ buf19 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32)
521
+ buf21 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32)
522
+ # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
523
+ stream1 = get_raw_stream(1)
524
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf15, buf19, buf21, 128, 16, stream=stream1)
525
+ del buf15
526
+ return (buf19, buf21, buf22, buf24, buf14, buf13, buf7, buf6, )
527
+
528
+ runner = Runner(partitions=[])
529
+ call = runner.call
530
+ recursively_apply_fns = runner.recursively_apply_fns
531
+
532
+
533
+ def benchmark_compiled_module(times=10, repeat=10):
534
+ from torch._dynamo.testing import rand_strided
535
+ from torch._inductor.utils import print_performance
536
+ arg0_1 = rand_strided((8, ), (1, ), device='cuda:1', dtype=torch.int64)
537
+ fn = lambda: call([arg0_1])
538
+ return print_performance(fn, times=times, repeat=repeat)
539
+
540
+
541
+ if __name__ == "__main__":
542
+ from torch._inductor.wrapper_benchmark import compiled_module_main
543
+ compiled_module_main('None', benchmark_compiled_module)
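# Illustrative note, not part of the generated file: `call` is the entry point
# torch.compile dispatches to for this graph.  It takes a single (8,)-shaped
# int64 tensor (what looks like per-sample sequence lengths, compared against
# block positions in the mask kernel above) and returns eight block-mask
# buffers: indices and counts for the partial/full, kv-major/q-major variants.
# The wrapper pins CUDA device index 1, so a manual run needs that device:
#
#     import torch
#     seq_lens = torch.randint(1, 2048, (8,), dtype=torch.int64, device="cuda:1")
#     outputs = call([seq_lens])   # note: call() empties the argument list
#     assert len(outputs) == 8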
SpecForge-ext/cache/compiled_kernels/5p/c5pbkg5eq64emuv25ukki7a5dxvn2p2sh6jeiwb6b54tbidps5w7.py ADDED
@@ -0,0 +1,552 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = True
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* are defined on the block-sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1
88
+
89
+ ZQ = 8
90
+ HQ = 32
91
+ Q_LEN = 2048
92
+ ZKV = 8
93
+ KV_LEN = 2048
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for the batch dimension: a) (ZKV == ZQ), where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 8
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = 16
130
+ stride_kv_idx_h = 256
131
+ stride_kv_idx_m = 16
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to it
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
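# Illustrative worked example, not part of the generated file: with
# SPARSE_BLOCK = 128, BLOCK = 64 and therefore SPARSE_BLOCK_MULTIPLE = 2,
# get_offset_for_next_block advances the pointer by BLOCK = 64 while still
# inside a sparse block ((loop_iter + 1) % 2 != 0).  On the tile that finishes
# a sparse block it instead jumps by
#     (next_block - cur_block) * 128 - (2 - 1) * 64
# e.g. going from sparse block 3 to sparse block 7 the jump is
#     (7 - 3) * 128 - 64 = 448 rows,
# which lands the pointer exactly at the first BLOCK_N tile of block 7.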
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
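# Illustrative call-site sketch, not part of the generated file (offsets and
# strides below are assumptions): load_checked_2d picks one of four masking
# strategies at compile time.
#
#     q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk,
#                         IS_DIVISIBLE_M=True,   # Q_LEN % BLOCK_M == 0
#                         IS_DIVISIBLE_N=True,   # head dim needs no padding
#                         M_LEN=Q_LEN, N_LEN=QK_HEAD_DIM)
#
# With both flags True the helper lowers to an unmasked tl.load; any False flag
# adds the corresponding bounds mask with 0.0 padding.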
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = True
324
+ SM_SCALE : tl.constexpr = 0.08838834764831843
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB: reversed order since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid accessing memory out of bounds,
357
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = tl.full([1], False, tl.int1)
371
+ tmp2 = (m)
372
+ tmp3 = (n)
373
+ tmp4 = tmp2 >= tmp3
374
+ tmp5 = tmp3.to(tl.int64)
375
+ tmp6 = (off_z)
376
+ tmp7 = tl.load(in_ptr9 + tmp6)
377
+ tmp8 = tmp5 < tmp7
378
+ tmp9 = tmp2.to(tl.int64)
379
+ tmp10 = tmp9 < tmp7
380
+ tmp11 = tmp8 & tmp10
381
+ tmp12 = tmp4 & tmp11
382
+ tmp13 = tmp1 | tmp12
383
+ tmp14 = tl.full([1], 2048, tl.int32)
384
+ tmp15 = tmp3 >= tmp14
385
+ tmp16 = (tmp3 % tmp14)
386
+ tmp17 = tl.full([1], 0, tl.int32)
387
+ tmp18 = tmp16 != tmp17
388
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
389
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
390
+ tmp21 = tmp19 != tmp20
391
+ tmp22 = tmp18 & tmp21
392
+ tmp23 = tmp16 + tmp14
393
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
394
+ tmp25 = tmp24.to(tl.int64)
395
+ tmp26 = tmp25 < tmp7
396
+ tmp27 = tmp15 & tmp26
397
+ tmp28 = tmp3 - tmp2
398
+ tmp29 = (tmp28 % tmp14)
399
+ tmp30 = tmp29 != tmp17
400
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
401
+ tmp32 = tmp31 != tmp20
402
+ tmp33 = tmp30 & tmp32
403
+ tmp34 = tmp29 + tmp14
404
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
405
+ tmp36 = tmp35 == tmp17
406
+ tmp37 = tmp27 & tmp36
407
+ tmp38 = tmp13 | tmp37
408
+ mask_mod_output = tmp38
409
+
410
+
411
+ if CHECK_BLOCK_BOUNDARY:
412
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
413
+ # apply mask for partially unmasked blocks
414
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
415
+
416
+ if not PRESCALE_QK:
417
+ post_mod_scores *= RCP_LN2
418
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419
+
420
+ # -- compute scaling constant ---
421
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
422
+ if not ROWS_GUARANTEED_SAFE:
423
+ masked_out_rows = (m_ij == float("-inf"))
424
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
425
+ else:
426
+ m_ij_masked = m_ij
427
+
428
+ alpha = tl.math.exp2(m_i - m_ij_masked)
429
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
430
+
431
+ # NB: l_i update is pulled up here since it's a bit faster
432
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
433
+ # m_ij
434
+ l_i = l_i * alpha + tl.sum(p, 1)
435
+ # # -- scale and update acc --
436
+ acc = acc * alpha[:, None]
437
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
438
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
439
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
440
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
441
+
442
+ # -- update m_i
443
+ m_i = m_ij
444
+
445
+ return acc, l_i, m_i
446
+
447
+ @triton.jit
448
+ def forward_inner(
449
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
450
+ q, K, V,
451
+ desc_k, desc_v, Q_LEN, KV_LEN,
452
+ # accumulated values
453
+ acc, l_i, m_i,
454
+ # Offsets used as inputs to score_mod & mask_mod
455
+ # of size [BLOCK_M, BLOCK_N] or scalar.
456
+ off_z, off_h, offs_m, offs_n,
457
+ # Offsets needed for TMA loads
458
+ kv_start,
459
+ # blocksparse data
460
+ kv_indices, kv_num_blocks,
461
+ # start kv and end kv block
462
+ block_n_start, block_n_end,
463
+ MATMUL_PRECISION,
464
+ # Strides for K and V
465
+ stride_kk, stride_kn, stride_vn, stride_vk,
466
+ IS_FULL_BLOCKS,
467
+ ):
468
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
469
+ PRESCALE_QK : tl.constexpr = False
470
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
471
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
472
+ WRITE_DQ : tl.constexpr = True
473
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
474
+ OUTPUT_MAX : tl.constexpr = False
475
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
476
+ IS_DIVISIBLE : tl.constexpr = True
477
+ SM_SCALE : tl.constexpr = 0.08838834764831843
478
+ GQA_SHARED_HEADS : tl.constexpr = 4
479
+ HAS_FULL_BLOCKS : tl.constexpr = True
480
+ QK_HEAD_DIM : tl.constexpr = 128
481
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
482
+ V_HEAD_DIM : tl.constexpr = 128
483
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
484
+ SAFE_HEAD_DIM : tl.constexpr = True
485
+ USE_TMA : tl.constexpr = False
486
+ BLOCK_M : tl.constexpr = 128
487
+ BLOCK_N : tl.constexpr = 64
488
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
489
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
490
+ INDEX_DTYPE : tl.constexpr = tl.int32
491
+
492
+
493
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
494
+ RCP_LN2: tl.constexpr = 1.44269504
495
+
496
+ if PRESCALE_QK:
497
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
498
+
499
+ kv_offset = 0
500
+
501
+ # loop over k, v and update accumulator until block_n_end
502
+ for start_n in range(block_n_start, block_n_end):
503
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) hint from triton_fused_attention.
504
+ if IS_DIVISIBLE:
505
+ acc, l_i, m_i = forward_block_mn(
506
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
507
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
508
+ # accumulated values
509
+ acc, l_i, m_i,
510
+ # Offsets
511
+ off_z, off_h, offs_m, offs_n,
512
+ # Offsets needed for TMA loads
513
+ kv_start,
514
+ kv_offset,
515
+ MATMUL_PRECISION, RCP_LN2,
516
+ # Strides for K and V
517
+ stride_kk, stride_kn, stride_vn, stride_vk,
518
+ IS_FULL_BLOCKS,
519
+ )
520
+ else:
521
+ # Benchmarks show that even when we apply mod & mask to each block for a non-divisible seqlen,
522
+ # it's on par or slightly faster than only applying to the last block in fwd.
523
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
524
+ # to the last block because it is significantly faster.
525
+ acc, l_i, m_i = forward_block_mn(
526
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
527
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
528
+ # accumulated values
529
+ acc, l_i, m_i,
530
+ # Offsets
531
+ off_z, off_h, offs_m, offs_n,
532
+ # Offsets needed for TMA loads
533
+ kv_start,
534
+ kv_offset,
535
+ MATMUL_PRECISION, RCP_LN2,
536
+ # Strides for K and V
537
+ stride_kk, stride_kn, stride_vn, stride_vk,
538
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
539
+ )
540
+
541
+
542
+
543
+ offset = get_offset_for_next_block(
544
+ start_n, kv_indices, kv_num_blocks,
545
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
546
+ )
547
+
548
+ offs_n = offs_n + offset
549
+ kv_offset += offset
550
+
551
+
552
+ return acc, l_i, m_i
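# Illustrative sketch, not part of the generated file: forward_inner and
# forward_block_mn above implement the streaming ("online") softmax used by
# flash/flex attention.  Stripped of the Triton plumbing, one tile update looks
# roughly like the code below (base-2 exponentials, matching RCP_LN2 above;
# all names are assumptions, not taken from the generated kernel):
import torch

def online_softmax_step(scores, acc, l_i, m_i, v):
    """Fold one [M, N] score tile and its [N, D] value tile into the running state."""
    m_ij = torch.maximum(m_i, scores.max(dim=1).values)  # new running row max
    alpha = torch.exp2(m_i - m_ij)                        # rescale old state
    p = torch.exp2(scores - m_ij[:, None])                # tile probabilities
    l_i = l_i * alpha + p.sum(dim=1)                      # running denominator
    acc = acc * alpha[:, None] + p @ v                    # running numerator
    return acc, l_i, m_ij

# After the last tile the kernel divides acc by l_i[:, None] and stores the
# logsumexp as m_i + log2(l_i), both in base 2.  The kernel additionally guards
# fully masked-out rows (l_i == 0), which this sketch omits.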
SpecForge-ext/cache/compiled_kernels/5s/c5siycmobmba5rqczjfbtd45di6el6qnpizugzs3hsg4jzkcqnpk.py ADDED
@@ -0,0 +1,161 @@
1
+ # AOT ID: ['11_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/lw/clwnecq6ifpvev5aiszbhu6i732z6eomppbbe2l6ohgsvjmgczzn.py
38
+ # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
39
+ # Source node to ATen node mapping:
40
+ # target_head => convert_element_type
41
+ # target_p => div
42
+ # Graph fragment:
43
+ # %arg1_1 : Tensor "bf16[2, s67, 32000][32000*s67, 32000, 1]cuda:4" = PlaceHolder[target=arg1_1]
44
+ # %getitem : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:4" = PlaceHolder[target=getitem]
45
+ # %getitem_1 : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:4" = PlaceHolder[target=getitem_1]
46
+ # %convert_element_type : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg1_1, torch.float32), kwargs = {})
47
+ # %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {})
48
+ # %sub_tensor : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {})
49
+ # %exp_default : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {})
50
+ # %div : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {})
51
+ # return %getitem,%getitem_1,%div
52
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', '''
53
+ import triton
54
+ import triton.language as tl
55
+
56
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
57
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
58
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
59
+ triton_helpers.set_driver_to_gpu()
60
+
61
+ @triton_heuristics.reduction(
62
+ size_hints={'x': 4096, 'r0_': 32768},
63
+ reduction_hint=ReductionHint.INNER,
64
+ filename=__file__,
65
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
66
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
67
+ )
68
+ @triton.jit
69
+ def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
70
+ r0_numel = 32000
71
+ rnumel = r0_numel
72
+ RBLOCK: tl.constexpr = R0_BLOCK
73
+ xoffset = tl.program_id(0) * XBLOCK
74
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
75
+ xmask = xindex < xnumel
76
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
77
+ rbase = r0_base
78
+ x0 = xindex
79
+ _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
80
+ _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
81
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
82
+ r0_index = r0_offset + r0_base
83
+ r0_mask = r0_index < r0_numel
84
+ roffset = r0_offset
85
+ rindex = r0_index
86
+ r0_1 = r0_index
87
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
88
+ tmp1 = tmp0.to(tl.float32)
89
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
90
+
91
+ _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
92
+ _tmp3_max, _tmp3_sum, tmp2, False
93
+ )
94
+
95
+ _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max)
96
+ _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum)
97
+
98
+ tmp3, tmp4 = triton_helpers.online_softmax_reduce(
99
+ _tmp3_max, _tmp3_sum, 1, False)
100
+ tmp3 = tmp3[:, None]
101
+ tmp4 = tmp4[:, None]
102
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
103
+ r0_index = r0_offset + r0_base
104
+ r0_mask = r0_index < r0_numel
105
+ roffset = r0_offset
106
+ rindex = r0_index
107
+ r0_1 = r0_index
108
+ tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
109
+ tmp6 = tmp5.to(tl.float32)
110
+ tmp7 = tmp6 - tmp3
111
+ tmp8 = libdevice.exp(tmp7)
112
+ tmp9 = (tmp8 / tmp4)
113
+ tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask)
114
+ ''', device_str='cuda')
115
+
116
+
117
+ async_compile.wait(globals())
118
+ del async_compile
119
+
120
+ class Runner:
121
+ def __init__(self, partitions):
122
+ self.partitions = partitions
123
+
124
+ def recursively_apply_fns(self, fns):
125
+ new_callables = []
126
+ for fn, c in zip(fns, self.partitions):
127
+ new_callables.append(fn(c))
128
+ self.partitions = new_callables
129
+
130
+ def call(self, args):
131
+ arg0_1, arg1_1 = args
132
+ args.clear()
133
+ s67 = arg0_1
134
+ assert_size_stride(arg1_1, (2, s67, 32000), (32000*s67, 32000, 1))
135
+ with torch.cuda._DeviceGuard(4):
136
+ torch.cuda.set_device(4)
137
+ buf2 = empty_strided_cuda((2, s67, 32000), (32000*s67, 32000, 1), torch.float32)
138
+ # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
139
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel = 2*s67
140
+ stream4 = get_raw_stream(4)
141
+ triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg1_1, buf2, triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel, 32000, stream=stream4)
142
+ del arg1_1
143
+ return (buf2, )
144
+
145
+ runner = Runner(partitions=[])
146
+ call = runner.call
147
+ recursively_apply_fns = runner.recursively_apply_fns
148
+
149
+
150
+ def benchmark_compiled_module(times=10, repeat=10):
151
+ from torch._dynamo.testing import rand_strided
152
+ from torch._inductor.utils import print_performance
153
+ arg0_1 = 1543
154
+ arg1_1 = rand_strided((2, 1543, 32000), (49376000, 32000, 1), device='cuda:4', dtype=torch.bfloat16)
155
+ fn = lambda: call([arg0_1, arg1_1])
156
+ return print_performance(fn, times=times, repeat=repeat)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ from torch._inductor.wrapper_benchmark import compiled_module_main
161
+ compiled_module_main('None', benchmark_compiled_module)
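This inference module materializes target_p = softmax(target_head.float(), dim=-1) over a [2, s67, 32000] bf16 logits tensor in two reduction passes: the first keeps a running max and a rescaled running sum per row (the online-softmax combine), the second normalizes. A minimal reference of that two-pass scheme, assuming triton_helpers.online_softmax_combine has the usual running-max/rescaled-sum semantics:

import torch

def online_softmax_reference(logits: torch.Tensor, chunk: int = 4096) -> torch.Tensor:
    x = logits.float()
    m = torch.full(x.shape[:-1], float("-inf"))   # running max per row
    s = torch.zeros_like(m)                       # running (rescaled) sum per row
    for piece in x.split(chunk, dim=-1):          # R0_BLOCK-sized tiles of the vocab dim
        m_new = torch.maximum(m, piece.max(dim=-1).values)
        s = s * torch.exp(m - m_new) + torch.exp(piece - m_new.unsqueeze(-1)).sum(-1)
        m = m_new
    return torch.exp(x - m.unsqueeze(-1)) / s.unsqueeze(-1)

logits = torch.randn(2, 1543, 32000, dtype=torch.bfloat16)
ref = torch.softmax(logits.float(), dim=-1)
assert torch.allclose(online_softmax_reference(logits), ref, atol=1e-5)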
SpecForge-ext/cache/compiled_kernels/5u/235c5fbee66a14cc3d65896905ec816ec90c51ba6594c4a627960306977eb07c.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 32, "R0_BLOCK": 16, "num_warps": 4, "num_stages": 1, "configs_hash": "21ad1ee516cd6d15e1fb8e88c10082cd54bef654f8a281c7d5ccd54b6509a685", "found_by_coordesc": false, "time_taken_ms": 28, "triton_cache_hash": "2HBOMUT44J5WFCUWYGRFAAS3HGVNDHLHT7HCSXUCAOIKU6XGJNTA"}
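The .best_config files stored next to the kernels record the launch configuration that won autotuning (block sizes, num_warps, num_stages) so later runs can skip the search. A small hedged sketch of reading one back, assuming it is the plain JSON shown above:

import json

path = "SpecForge-ext/cache/compiled_kernels/5u/235c5fbee66a14cc3d65896905ec816ec90c51ba6594c4a627960306977eb07c.best_config"
with open(path) as f:
    cfg = json.load(f)
# e.g. XBLOCK=32, R0_BLOCK=16, num_warps=4, num_stages=1 as recorded above
print(cfg["XBLOCK"], cfg["R0_BLOCK"], cfg["num_warps"], cfg["num_stages"])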
SpecForge-ext/cache/compiled_kernels/6b/c6beknosybos5d54llineldguuueh3kpjlkiuzm4pkorx7g6mjh6.py ADDED
@@ -0,0 +1,45 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 4096, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 32000
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x0 = xindex
28
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
29
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
30
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
31
+ r0_index = r0_offset + r0_base
32
+ r0_mask = r0_index < r0_numel
33
+ roffset = r0_offset
34
+ rindex = r0_index
35
+ r0_1 = r0_index
36
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
37
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
38
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
39
+ _tmp2, _tmp2_index, tmp1, rindex
40
+ )
41
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
42
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
43
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
44
+ tmp2 = tmp2_idx[:, None]
45
+ tl.store(out_ptr0 + (x0), tmp2, xmask)
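This reduction fuses a row-wise argmax over the 32000-wide vocabulary dimension, carrying a running (max value, index) pair per row via triton_helpers.maximum_with_index. A one-line eager equivalent, assuming in_ptr0 is a flattened [rows, 32000] bf16 logits view and out_ptr0 receives one int64 index per row:

import torch

logits = torch.randn(4, 32000, dtype=torch.bfloat16)   # hypothetical [rows, vocab] view
token_ids = logits.float().argmax(dim=-1)               # same indices the fused kernel stores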
SpecForge-ext/cache/compiled_kernels/6b/c6bpf3ctcqs5wvcac26go3fcp5hdc2pxduwgba2cnxt52xqmp6mq.py ADDED
@@ -0,0 +1,334 @@
1
+ # AOT ID: ['2_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zg/czg53pk3l24wn74a6bylpzbgb44kx2zfplies7n5uiiogfzwg4z2.py
38
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
39
+ # Source node to ATen node mapping:
40
+ # hidden_states => convert_element_type
41
+ # hidden_states_1 => mul_16
42
+ # to_1 => convert_element_type_1
43
+ # Graph fragment:
44
+ # %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7" = PlaceHolder[target=tangents_1]
45
+ # %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7" = PlaceHolder[target=primals_4]
46
+ # %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7" = PlaceHolder[target=rsqrt]
47
+ # %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
48
+ # %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
49
+ # %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {})
50
+ # %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {})
51
+ # %sum_1 : Tensor "bf16[1, 1, s33][s33, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {})
52
+ # return %buf0
53
+ triton_red_fused__to_copy_mul_sum_0 = async_compile.triton('triton_red_fused__to_copy_mul_sum_0', '''
54
+ import triton
55
+ import triton.language as tl
56
+
57
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
58
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
59
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
60
+ triton_helpers.set_driver_to_gpu()
61
+
62
+ @triton_heuristics.reduction(
63
+ size_hints={'x': 131072, 'r0_': 128},
64
+ reduction_hint=ReductionHint.OUTER,
65
+ filename=__file__,
66
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
67
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
68
+ )
69
+ @triton.jit
70
+ def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
71
+ rnumel = r0_numel
72
+ RBLOCK: tl.constexpr = R0_BLOCK
73
+ xoffset = tl.program_id(0) * XBLOCK
74
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
75
+ xmask = xindex < xnumel
76
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
77
+ rbase = r0_base
78
+ x1 = xindex // ks0
79
+ x0 = (xindex % ks0)
80
+ _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
81
+ x3 = xindex
82
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
83
+ r0_index = r0_offset + r0_base
84
+ r0_mask = r0_index < r0_numel
85
+ roffset = r0_offset
86
+ rindex = r0_index
87
+ r0_2 = r0_index
88
+ tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32)
89
+ tmp1 = ks1*ks2
90
+ tmp2 = tmp0 < tmp1
91
+ tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
92
+ tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
93
+ tmp5 = tmp4.to(tl.float32)
94
+ tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0)
95
+ tmp7 = tmp5 * tmp6
96
+ tmp8 = tmp7.to(tl.float32)
97
+ tmp9 = tmp3 * tmp8
98
+ tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
99
+ tmp11 = tl.where(tmp2, tmp9, tmp10)
100
+ tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
101
+ tmp14 = _tmp13 + tmp12
102
+ _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13)
103
+ tmp13 = tl.sum(_tmp13, 1)[:, None]
104
+ tl.store(out_ptr0 + (x3), tmp13, xmask)
105
+ ''', device_str='cuda')
106
+
107
+
108
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ut/cutp3chhk5c6s5fxb2gqzhrx5hjq4ltt3ybguoemttw3toknshg6.py
109
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
110
+ # Source node to ATen node mapping:
111
+ # hidden_states => convert_element_type
112
+ # hidden_states_1 => mul_16
113
+ # to_1 => convert_element_type_1
114
+ # Graph fragment:
115
+ # %buf0 : Tensor "f32[1, 1, s33, 32][32*s33, 32*s33, 1, s33]cuda:7" = PlaceHolder[target=buf0]
116
+ # %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
117
+ # %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
118
+ # %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {})
119
+ # %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {})
120
+ # %sum_1 : Tensor "bf16[1, 1, s33][s33, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {})
121
+ # return %sum_1
122
+ triton_per_fused__to_copy_mul_sum_1 = async_compile.triton('triton_per_fused__to_copy_mul_sum_1', '''
123
+ import triton
124
+ import triton.language as tl
125
+
126
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
127
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
128
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
129
+ triton_helpers.set_driver_to_gpu()
130
+
131
+ @triton_heuristics.persistent_reduction(
132
+ size_hints={'x': 4096, 'r0_': 32},
133
+ reduction_hint=ReductionHint.OUTER,
134
+ filename=__file__,
135
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
136
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
137
+ )
138
+ @triton.jit
139
+ def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
140
+ r0_numel = 32
141
+ R0_BLOCK: tl.constexpr = 32
142
+ rnumel = r0_numel
143
+ RBLOCK: tl.constexpr = R0_BLOCK
144
+ xoffset = tl.program_id(0) * XBLOCK
145
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
146
+ xmask = xindex < xnumel
147
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
148
+ r0_offset = 0
149
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
150
+ roffset = r0_offset
151
+ rindex = r0_index
152
+ r0_1 = r0_index
153
+ x0 = xindex
154
+ tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0)
155
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
156
+ tmp3 = tl.where(xmask, tmp1, 0)
157
+ tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32)
158
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
159
+ ''', device_str='cuda')
160
+
161
+
162
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/im/cimq7s4zgz63carjnhuvinchsq4odrr475l6qsymkihvbxvheq7a.py
163
+ # Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add]
164
+ # Source node to ATen node mapping:
165
+ # hidden_states => convert_element_type
166
+ # Graph fragment:
167
+ # %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7" = PlaceHolder[target=tangents_1]
168
+ # %primals_7 : Tensor "bf16[s33][1]cuda:7" = PlaceHolder[target=primals_7]
169
+ # %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7" = PlaceHolder[target=primals_4]
170
+ # %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7" = PlaceHolder[target=rsqrt]
171
+ # %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, s47*s87]cuda:7" = PlaceHolder[target=sum_2]
172
+ # %mul_27 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %primals_7), kwargs = {})
173
+ # %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {})
174
+ # %convert_element_type_2 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_27, torch.float32), kwargs = {})
175
+ # %mul_29 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %convert_element_type), kwargs = {})
176
+ # %mul_30 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %rsqrt), kwargs = {})
177
+ # %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_29, [2], True), kwargs = {})
178
+ # %pow_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%rsqrt, 3), kwargs = {})
179
+ # %mul_31 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%sum_2, -0.5), kwargs = {})
180
+ # %mul_32 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_31, %pow_2), kwargs = {})
181
+ # %expand : Tensor "f32[s47, s87, s33][s87, 1, 0]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%mul_32, [%primals_1, %primals_2, %primals_3]), kwargs = {})
182
+ # %div : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand, %primals_3), kwargs = {})
183
+ # %pow_3 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 1.0), kwargs = {})
184
+ # %mul_33 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_3, 2.0), kwargs = {})
185
+ # %mul_34 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div, %mul_33), kwargs = {})
186
+ # %add_37 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_30, %mul_34), kwargs = {})
187
+ # %convert_element_type_3 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_37, torch.bfloat16), kwargs = {})
188
+ # return %sum_2,%convert_element_type_3
189
+ triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2 = async_compile.triton('triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', '''
190
+ import triton
191
+ import triton.language as tl
192
+
193
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
194
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
195
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
196
+ triton_helpers.set_driver_to_gpu()
197
+
198
+ @triton_heuristics.reduction(
199
+ size_hints={'x': 4096, 'r0_': 4096},
200
+ reduction_hint=ReductionHint.INNER,
201
+ filename=__file__,
202
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
203
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
204
+ )
205
+ @triton.jit
206
+ def triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
207
+ rnumel = r0_numel
208
+ RBLOCK: tl.constexpr = R0_BLOCK
209
+ xoffset = tl.program_id(0) * XBLOCK
210
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
211
+ xmask = xindex < xnumel
212
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
213
+ rbase = r0_base
214
+ x0 = xindex
215
+ _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
216
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
217
+ r0_index = r0_offset + r0_base
218
+ r0_mask = r0_index < r0_numel
219
+ roffset = r0_offset
220
+ rindex = r0_index
221
+ r0_1 = r0_index
222
+ tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
223
+ tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
224
+ tmp4 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
225
+ tmp2 = tmp0 * tmp1
226
+ tmp3 = tmp2.to(tl.float32)
227
+ tmp5 = tmp4.to(tl.float32)
228
+ tmp6 = tmp3 * tmp5
229
+ tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
230
+ tmp9 = _tmp8 + tmp7
231
+ _tmp8 = tl.where(r0_mask & xmask, tmp9, _tmp8)
232
+ tmp8 = tl.sum(_tmp8, 1)[:, None]
233
+ tmp14 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
234
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
235
+ r0_index = r0_offset + r0_base
236
+ r0_mask = r0_index < r0_numel
237
+ roffset = r0_offset
238
+ rindex = r0_index
239
+ r0_1 = r0_index
240
+ tmp10 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
241
+ tmp11 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
242
+ tmp24 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
243
+ tmp12 = tmp10 * tmp11
244
+ tmp13 = tmp12.to(tl.float32)
245
+ tmp15 = tmp13 * tmp14
246
+ tmp16 = -0.5
247
+ tmp17 = tmp8 * tmp16
248
+ tmp18 = tmp14 * tmp14
249
+ tmp19 = tmp18 * tmp14
250
+ tmp20 = tmp17 * tmp19
251
+ tmp21 = ks0
252
+ tmp22 = tmp21.to(tl.float32)
253
+ tmp23 = (tmp20 / tmp22)
254
+ tmp25 = tmp24.to(tl.float32)
255
+ tmp26 = 2.0
256
+ tmp27 = tmp25 * tmp26
257
+ tmp28 = tmp23 * tmp27
258
+ tmp29 = tmp15 + tmp28
259
+ tmp30 = tmp29.to(tl.float32)
260
+ tl.store(out_ptr1 + (r0_1 + ks0*x0), tmp30, r0_mask & xmask)
261
+ ''', device_str='cuda')
262
+
263
+
264
+ async_compile.wait(globals())
265
+ del async_compile
266
+
267
+ class Runner:
268
+ def __init__(self, partitions):
269
+ self.partitions = partitions
270
+
271
+ def recursively_apply_fns(self, fns):
272
+ new_callables = []
273
+ for fn, c in zip(fns, self.partitions):
274
+ new_callables.append(fn(c))
275
+ self.partitions = new_callables
276
+
277
+ def call(self, args):
278
+ primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1 = args
279
+ args.clear()
280
+ s47 = primals_1
281
+ s87 = primals_2
282
+ s33 = primals_3
283
+ s82 = primals_6
284
+ assert_size_stride(primals_4, (s47, s87, s33), (s33*s87, s33, 1))
285
+ assert_size_stride(primals_7, (s33, ), (1, ))
286
+ assert_size_stride(rsqrt, (s47, s87, 1), (s87, 1, 1))
287
+ assert_size_stride(tangents_1, (s47, s87, s33), (s33*s87, s33, 1))
288
+ with torch.cuda._DeviceGuard(7):
289
+ torch.cuda.set_device(7)
290
+ buf0 = empty_strided_cuda((1, 1, s33, 32), (32*s33, 32*s33, 1, s33), torch.float32)
291
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
292
+ triton_red_fused__to_copy_mul_sum_0_xnumel = 32*s33
293
+ triton_red_fused__to_copy_mul_sum_0_r0_numel = (31 + s47*s87) // 32
294
+ stream7 = get_raw_stream(7)
295
+ triton_red_fused__to_copy_mul_sum_0.run(tangents_1, primals_4, rsqrt, buf0, s33, s47, s87, triton_red_fused__to_copy_mul_sum_0_xnumel, triton_red_fused__to_copy_mul_sum_0_r0_numel, stream=stream7)
296
+ buf1 = empty_strided_cuda((1, 1, s33), (s33, s33, 1), torch.bfloat16)
297
+ # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
298
+ stream7 = get_raw_stream(7)
299
+ triton_per_fused__to_copy_mul_sum_1.run(buf0, buf1, s33, s33, 32, stream=stream7)
300
+ del buf0
301
+ buf3 = empty_strided_cuda((s47, s87, s33), (s33*s87, s33, 1), torch.bfloat16)
302
+ # Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add]
303
+ triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel = s47*s87
304
+ stream7 = get_raw_stream(7)
305
+ triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2.run(tangents_1, primals_7, primals_4, rsqrt, buf3, s33, triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel, s33, stream=stream7)
306
+ del primals_4
307
+ del primals_7
308
+ del rsqrt
309
+ del tangents_1
310
+ return (None, None, None, buf3, None, None, reinterpret_tensor(buf1, (s33, ), (1, ), 0), )
311
+
312
+ runner = Runner(partitions=[])
313
+ call = runner.call
314
+ recursively_apply_fns = runner.recursively_apply_fns
315
+
316
+
317
+ def benchmark_compiled_module(times=10, repeat=10):
318
+ from torch._dynamo.testing import rand_strided
319
+ from torch._inductor.utils import print_performance
320
+ primals_1 = 2
321
+ primals_2 = 2048
322
+ primals_3 = 4096
323
+ primals_6 = 840433664
324
+ primals_4 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
325
+ primals_7 = rand_strided((4096, ), (1, ), device='cuda:7', dtype=torch.bfloat16)
326
+ rsqrt = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:7', dtype=torch.float32)
327
+ tangents_1 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
328
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1])
329
+ return print_performance(fn, times=times, repeat=repeat)
330
+
331
+
332
+ if __name__ == "__main__":
333
+ from torch._inductor.wrapper_benchmark import compiled_module_main
334
+ compiled_module_main('None', benchmark_compiled_module)
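The three kernels in this backward module implement the gradient of an RMSNorm layer: the first two accumulate the weight gradient (a split reduction of tangent times the normalized input over batch and sequence), and the third recombines the tangent with the saved rsqrt statistic to produce the input gradient. A hedged sketch of the eager module this is assumed to come from (the usual Llama-style RMSNorm; primals_4 would be the bf16 input, primals_7 the weight, rsqrt the saved 1/sqrt(mean(x^2)+eps)):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    hidden_states = x.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(x.dtype)

x = torch.randn(2, 2048, 4096, dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(4096, dtype=torch.bfloat16, requires_grad=True)
rms_norm(x, w).sum().backward()
# w.grad corresponds to buf1 and x.grad to buf3 returned by call() above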
SpecForge-ext/cache/compiled_kernels/6j/b801eb968d13baeef00c09ffebb7c203c75661545f70c7ec4ed906e946ad8a67.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "XAIV2GWX5UZL7NNOCKNWC2I6AATKI6664P6FTQPRXS2M4AR4WJWA"}
SpecForge-ext/cache/compiled_kernels/6j/c6jx5fvfijye7zqqg42xonpcdfuwatv7bizrwompd5o3dua57uju.py ADDED
@@ -0,0 +1,24 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 8192},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x0 = xindex
23
+ tmp0 = tl.full([1], 0, tl.int32)
24
+ tl.store(out_ptr0 + (x0), tmp0, xmask)
SpecForge-ext/cache/compiled_kernels/6o/c6obqatzdeyb7elxstetxuvmlhbvwph6buxkixqs4flvdn2x6vgl.py ADDED
@@ -0,0 +1,835 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, which is written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1
101
+
102
+ ZQ = 2
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 2
107
+ KV_LEN = ks0
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 2
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 16*ks1
149
+ stride_kv_idx_m = ks1
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ blocks that may require masking ~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = ks2
245
+ stride_q_idx_h = 16*ks3
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ blocks that may require masking ~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = False
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 2048
385
+ KV_LEN = ks0
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = False
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
466
+ # that the M reads may go out of bounds for the PIDs spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = tl.full([1], 2048, tl.int32)
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = False
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = 2048
596
+ KV_LEN = ks0
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = False
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
682
+ # that the n reads may go out of bounds for the PIDs spanning the KV_LEN boundary
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = tl.full([1], 2048, tl.int32)
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
SpecForge-ext/cache/compiled_kernels/7g/59ff39d5526de7bb833fbd386ca3ce564bdaf6828f559a423e599b5ad90d0456.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "6FB7I6IASCIGI3DSKLBL4Q2CXFFWPYWXW7AMHNUUDLPGKUCB3PDA"}
SpecForge-ext/cache/compiled_kernels/7m/c7mmadjna7dltm72lxvsoktdadnw2jtxufsj2eoflefh2r5jo4gq.py ADDED
@@ -0,0 +1,24 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 8192},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x0 = xindex
23
+ tmp0 = tl.full([1], 0, tl.int32)
24
+ tl.store(out_ptr0 + (x0), tmp0, xmask)
SpecForge-ext/cache/compiled_kernels/7m/e130479b4d145e755b390ab3b709dd817d1548c0596f91391e7581de8609a9eb.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "XAIV2GWX5UZL7NNOCKNWC2I6AATKI6664P6FTQPRXS2M4AR4WJWA"}
SpecForge-ext/cache/compiled_kernels/7o/c7oiol3zozs5oktlpjhg3lu46rhbgu3bqq6yibefmn2imo6bua5k.py ADDED
@@ -0,0 +1,48 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 16384, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 2097152000}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 16384
20
+ r0_numel = 32000
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x0 = (xindex % 2048)
29
+ x1 = xindex // 2048
30
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
31
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
32
+ x3 = xindex
33
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
34
+ r0_index = r0_offset + r0_base
35
+ r0_mask = r0_index < r0_numel
36
+ roffset = r0_offset
37
+ rindex = r0_index
38
+ r0_2 = r0_index
39
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0)
40
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
41
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
42
+ _tmp2, _tmp2_index, tmp1, rindex
43
+ )
44
+ _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
45
+ _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
46
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
47
+ tmp2 = tmp2_idx[:, None]
48
+ tl.store(out_ptr0 + (x3), tmp2, None)
SpecForge-ext/cache/compiled_kernels/7z/c7z2jbjub3aupgnechol65vkvi5ruwpylzosdbqvscdyxmreb3jy.py ADDED
@@ -0,0 +1,86 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 32, 'r0_': 16},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ xnumel = 32
20
+ r0_numel = 16
21
+ R0_BLOCK: tl.constexpr = 16
22
+ rnumel = r0_numel
23
+ RBLOCK: tl.constexpr = R0_BLOCK
24
+ xoffset = tl.program_id(0) * XBLOCK
25
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
26
+ xmask = xindex < xnumel
27
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
28
+ r0_offset = 0
29
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
30
+ roffset = r0_offset
31
+ rindex = r0_index
32
+ r0_1 = r0_index
33
+ x0 = xindex
34
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0)
35
+ tmp1 = tl.full([1, 1], 0, tl.int64)
36
+ tmp2 = tmp0 > tmp1
37
+ tmp3 = tl.full([1, 1], 16384, tl.int64)
38
+ tmp4 = tmp0 < tmp3
39
+ tmp5 = tmp2 & tmp4
40
+ tmp6 = tmp5.to(tl.int8)
41
+ tmp7 = tmp6.to(tl.int32)
42
+ tmp8 = r0_1
43
+ tmp9 = tmp8.to(tl.int16)
44
+ tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
45
+ tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
46
+ tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True)
47
+ tmp14 = tmp0 == tmp3
48
+ tmp15 = tmp14.to(tl.int8)
49
+ tmp16 = tmp15.to(tl.int32)
50
+ tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
51
+ tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True)
52
+ tmp20 = tmp7.to(tl.int64)
53
+ tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK])
54
+ tmp23 = tl.where(xmask, tmp21, 0)
55
+ tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64)
56
+ tmp25 = tmp16.to(tl.int64)
57
+ tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
58
+ tmp28 = tl.where(xmask, tmp26, 0)
59
+ tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64)
60
+ tmp30 = tmp24.to(tl.int32)
61
+ tmp31 = tmp29.to(tl.int32)
62
+ tmp32 = tmp13.to(tl.int64)
63
+ tmp33 = tmp32.to(tl.int32)
64
+ tmp34 = tmp8 < tmp30
65
+ tmp35 = tl.full([1, 1], 16, tl.int32)
66
+ tmp36 = tl.where(tmp34, tmp33, tmp35)
67
+ tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32)
68
+ tmp38 = tmp36 + tmp37
69
+ tmp39 = tmp36 < 0
70
+ tmp40 = tl.where(tmp39, tmp38, tmp36)
71
+ tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17")
72
+ tmp42 = tl.full([1, 1], 1, tl.int32)
73
+ tmp43 = tmp19.to(tl.int64)
74
+ tmp44 = tmp43.to(tl.int32)
75
+ tmp45 = tmp8 < tmp31
76
+ tmp46 = tl.where(tmp45, tmp44, tmp35)
77
+ tmp47 = tmp46 + tmp37
78
+ tmp48 = tmp46 < 0
79
+ tmp49 = tl.where(tmp48, tmp47, tmp46)
80
+ tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17")
81
+ tl.store(out_ptr4 + (x0), tmp30, xmask)
82
+ tl.store(out_ptr5 + (x0), tmp31, xmask)
83
+ tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask)
84
+ tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
85
+ tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask)
86
+ tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
SpecForge-ext/cache/compiled_kernels/7z/c7z6kbhlhnd55iz3suxpzcfjhjv7p7i2zelu2nitjoegrwczbdyf.py ADDED
@@ -0,0 +1,52 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 131072, 'r0_': 128},
12
+ reduction_hint=ReductionHint.OUTER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ rnumel = r0_numel
20
+ RBLOCK: tl.constexpr = R0_BLOCK
21
+ xoffset = tl.program_id(0) * XBLOCK
22
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
23
+ xmask = xindex < xnumel
24
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
25
+ rbase = r0_base
26
+ x1 = xindex // ks0
27
+ x0 = (xindex % ks0)
28
+ _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
29
+ x3 = xindex
30
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
31
+ r0_index = r0_offset + r0_base
32
+ r0_mask = r0_index < r0_numel
33
+ roffset = r0_offset
34
+ rindex = r0_index
35
+ r0_2 = r0_index
36
+ tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32)
37
+ tmp1 = ks1*ks2
38
+ tmp2 = tmp0 < tmp1
39
+ tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
40
+ tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp5 = tmp4.to(tl.float32)
42
+ tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0)
43
+ tmp7 = tmp5 * tmp6
44
+ tmp8 = tmp7.to(tl.float32)
45
+ tmp9 = tmp3 * tmp8
46
+ tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
47
+ tmp11 = tl.where(tmp2, tmp9, tmp10)
48
+ tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
49
+ tmp14 = _tmp13 + tmp12
50
+ _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13)
51
+ tmp13 = tl.sum(_tmp13, 1)[:, None]
52
+ tl.store(out_ptr0 + (x3), tmp13, xmask)
SpecForge-ext/cache/compiled_kernels/7z/de291d239bdb6c33244f90904700e0423d0a8026bdcf04c4cb1f87b0edee041b.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ"}
SpecForge-ext/cache/compiled_kernels/aa/caa67m6yhgzsw5semsgkn3vvui6pjb2e2mxtfb5xyoo3c5qle6ao.py ADDED
@@ -0,0 +1,320 @@
1
+ # AOT ID: ['4_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/n5/cn5h4iq6wlljobax2ulslga4k6zxontovelmyztexccj4qb2xkei.py
38
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
39
+ # Source node to ATen node mapping:
40
+ # cos => squeeze_1
41
+ # cos_1 => unsqueeze
42
+ # getitem => index
43
+ # getitem_1 => index_1
44
+ # sin => squeeze_3
45
+ # sin_1 => unsqueeze_1
46
+ # squeeze => squeeze
47
+ # squeeze_2 => squeeze_2
48
+ # Graph fragment:
49
+ # %tangents_2 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6" = PlaceHolder[target=tangents_2]
50
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:6" = PlaceHolder[target=primals_8]
51
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_6]
52
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_4]
53
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
54
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
55
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
56
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
57
+ # %mul_84 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {})
58
+ # %slice_5 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {})
59
+ # %slice_6 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs = {})
60
+ # %neg_2 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {})
61
+ # %full_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_10, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:6, pin_memory: False})
62
+ # %slice_scatter_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 9223372036854775807), kwargs = {})
63
+ # %slice_scatter_default_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {})
64
+ # %add_100 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {})
65
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
66
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
67
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
68
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
69
+ # %mul_85 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {})
70
+ # %add_101 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {})
71
+ # return %add_101
72
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', '''
73
+ import triton
74
+ import triton.language as tl
75
+
76
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
77
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
78
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
79
+ triton_helpers.set_driver_to_gpu()
80
+
81
+ @triton_heuristics.pointwise(
82
+ size_hints={'x': 16777216},
83
+ filename=__file__,
84
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
85
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
86
+ min_elem_per_thread=0
87
+ )
88
+ @triton.jit
89
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
90
+ xoffset = tl.program_id(0) * XBLOCK
91
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
92
+ xmask = xindex < xnumel
93
+ x0 = (xindex % ks0)
94
+ x3 = xindex
95
+ x1 = ((xindex // ks0) % ks1)
96
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
97
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
98
+ tmp0 = x0
99
+ tmp1 = ks0 // 2
100
+ tmp2 = tmp0 >= tmp1
101
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
102
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
103
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
104
+ tmp6 = tmp4 + tmp5
105
+ tmp7 = tmp4 < 0
106
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
107
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
108
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
109
+ tmp11 = tmp3 * tmp10
110
+ tmp12 = -tmp11
111
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
112
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
113
+ tmp15 = 0.0
114
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
115
+ tmp17 = tmp0 < tmp1
116
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
117
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
118
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
119
+ tmp21 = tmp19 + tmp20
120
+ tmp22 = tmp19 < 0
121
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
122
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
123
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
124
+ tmp26 = tmp18 * tmp25
125
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
126
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
127
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
128
+ tmp30 = tmp16 + tmp29
129
+ tmp33 = ks3
130
+ tmp34 = tmp32 + tmp33
131
+ tmp35 = tmp32 < 0
132
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
133
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
134
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
135
+ tmp39 = tmp31 * tmp38
136
+ tmp40 = tmp30 + tmp39
137
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
138
+ ''', device_str='cuda')
139
+
140
+
141
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/eg/cegphctwzx57aawblx7563zff7jofvfpmllo4f2poi5emt43dc5t.py
142
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
143
+ # Source node to ATen node mapping:
144
+ # cos => squeeze_1
145
+ # cos_1 => unsqueeze
146
+ # getitem => index
147
+ # getitem_1 => index_1
148
+ # sin => squeeze_3
149
+ # sin_1 => unsqueeze_1
150
+ # squeeze => squeeze
151
+ # squeeze_2 => squeeze_2
152
+ # Graph fragment:
153
+ # %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6" = PlaceHolder[target=tangents_1]
154
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:6" = PlaceHolder[target=primals_8]
155
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_6]
156
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_4]
157
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
158
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
159
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
160
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
161
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
162
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
163
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
164
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
165
+ # %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {})
166
+ # %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {})
167
+ # %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {})
168
+ # %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {})
169
+ # %full_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:6, pin_memory: False})
170
+ # %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), kwargs = {})
171
+ # %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {})
172
+ # %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {})
173
+ # %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {})
174
+ # %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {})
175
+ # return %add_107
176
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', '''
177
+ import triton
178
+ import triton.language as tl
179
+
180
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
181
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
182
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
183
+ triton_helpers.set_driver_to_gpu()
184
+
185
+ @triton_heuristics.pointwise(
186
+ size_hints={'x': 67108864},
187
+ filename=__file__,
188
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
189
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
190
+ min_elem_per_thread=0
191
+ )
192
+ @triton.jit
193
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
194
+ xoffset = tl.program_id(0) * XBLOCK
195
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
196
+ xmask = xindex < xnumel
197
+ x0 = (xindex % ks0)
198
+ x3 = xindex
199
+ x1 = ((xindex // ks0) % ks1)
200
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
201
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
202
+ tmp0 = x0
203
+ tmp1 = ks0 // 2
204
+ tmp2 = tmp0 >= tmp1
205
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
206
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
207
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
208
+ tmp6 = tmp4 + tmp5
209
+ tmp7 = tmp4 < 0
210
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
211
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
212
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
213
+ tmp11 = tmp3 * tmp10
214
+ tmp12 = -tmp11
215
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
216
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
217
+ tmp15 = 0.0
218
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
219
+ tmp17 = tmp0 < tmp1
220
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
221
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
222
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
223
+ tmp21 = tmp19 + tmp20
224
+ tmp22 = tmp19 < 0
225
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
226
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
227
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
228
+ tmp26 = tmp18 * tmp25
229
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
230
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
231
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
232
+ tmp30 = tmp16 + tmp29
233
+ tmp33 = ks3
234
+ tmp34 = tmp32 + tmp33
235
+ tmp35 = tmp32 < 0
236
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
237
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
238
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
239
+ tmp39 = tmp31 * tmp38
240
+ tmp40 = tmp30 + tmp39
241
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
242
+ ''', device_str='cuda')
243
+
244
+
245
+ async_compile.wait(globals())
246
+ del async_compile
247
+
248
+ class Runner:
249
+ def __init__(self, partitions):
250
+ self.partitions = partitions
251
+
252
+ def recursively_apply_fns(self, fns):
253
+ new_callables = []
254
+ for fn, c in zip(fns, self.partitions):
255
+ new_callables.append(fn(c))
256
+ self.partitions = new_callables
257
+
258
+ def call(self, args):
259
+ primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args
260
+ args.clear()
261
+ s24 = primals_2
262
+ s9 = primals_7
263
+ s48 = primals_10
264
+ s34 = primals_11
265
+ s92 = primals_1
266
+ s96 = primals_3
267
+ s79 = primals_5
268
+ assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1))
269
+ assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1))
270
+ assert_size_stride(primals_8, (1, s9), (s9, 1))
271
+ assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1))
272
+ assert_size_stride(tangents_2, (s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1))
273
+ with torch.cuda._DeviceGuard(6):
274
+ torch.cuda.set_device(6)
275
+ buf0 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1), torch.bfloat16)
276
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
277
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s9*s48*s48
278
+ stream6 = get_raw_stream(6)
279
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream6)
280
+ del tangents_2
281
+ buf1 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16)
282
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
283
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9
284
+ stream6 = get_raw_stream(6)
285
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream6)
286
+ del primals_4
287
+ del primals_6
288
+ del primals_8
289
+ del tangents_1
290
+ return (None, None, None, None, None, None, None, None, None, None, None, buf1, buf0, )
291
+
292
+ runner = Runner(partitions=[])
293
+ call = runner.call
294
+ recursively_apply_fns = runner.recursively_apply_fns
295
+
296
+
297
+ def benchmark_compiled_module(times=10, repeat=10):
298
+ from torch._dynamo.testing import rand_strided
299
+ from torch._inductor.utils import print_performance
300
+ primals_2 = 128
301
+ primals_7 = 2048
302
+ primals_10 = 8
303
+ primals_11 = 32
304
+ primals_1 = 2048
305
+ primals_3 = 5245440
306
+ primals_5 = 2048
307
+ floordiv = 64
308
+ add_96 = 64
309
+ primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:6', dtype=torch.bfloat16)
310
+ primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:6', dtype=torch.bfloat16)
311
+ primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:6', dtype=torch.int64)
312
+ tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:6', dtype=torch.bfloat16)
313
+ tangents_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:6', dtype=torch.bfloat16)
314
+ fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2])
315
+ return print_performance(fn, times=times, repeat=repeat)
316
+
317
+
318
+ if __name__ == "__main__":
319
+ from torch._inductor.wrapper_benchmark import compiled_module_main
320
+ compiled_module_main('None', benchmark_compiled_module)
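The module above is the Inductor-compiled backward graph for the rotary-embedding (RoPE) application: the graph fragments route the incoming tangent through slice/neg/slice_scatter against the gathered sin rows and add tangent * cos, which corresponds to the adjoint of the usual rotate_half formulation. Below is a minimal eager-mode sketch of that gradient, assuming the standard rotate_half RoPE layout; the shapes and the skipped position_ids gather are illustrative simplifications, not taken from SpecForge.

    import torch

    def rotate_half(x):
        # (x1, x2) -> (-x2, x1) over the last (head) dimension
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def rope_grad(tangent, cos, sin):
        # adjoint of q -> q * cos + rotate_half(q) * sin:
        # the transpose of rotate_half maps (g1, g2) -> (g2, -g1)
        g = tangent * sin
        g1, g2 = g.chunk(2, dim=-1)
        return tangent * cos + torch.cat((g2, -g1), dim=-1)

    # illustrative shapes; the compiled kernel additionally gathers cos/sin rows
    # by position_ids (primals_8), which is omitted here
    q = torch.randn(2, 4, 16, 8, requires_grad=True)
    cos = torch.randn(1, 1, 16, 8)
    sin = torch.randn(1, 1, 16, 8)
    out = q * cos + rotate_half(q) * sin
    tangent = torch.randn_like(out)
    out.backward(tangent)
    print(torch.allclose(q.grad, rope_grad(tangent, cos, sin)))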
SpecForge-ext/cache/compiled_kernels/aa/caabkjzbaqm7hrv3ypoalyjx45pdt7jezorxxk75d4cahg2knncu.py ADDED
@@ -0,0 +1,89 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 1024, 'r0_': 16384},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 16384
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x0 = (xindex % ks0)
28
+ x1 = ((xindex // ks0) % 16)
29
+ x2 = xindex // ks2
30
+ _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
31
+ x5 = xindex
32
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
33
+ r0_index = r0_offset + r0_base
34
+ r0_mask = r0_index < r0_numel
35
+ roffset = r0_offset
36
+ rindex = r0_index
37
+ r0_3 = (r0_index % 128)
38
+ r0_4 = r0_index // 128
39
+ tmp0 = r0_3 + 128*x0
40
+ tmp1 = ks1
41
+ tmp2 = tmp0 < tmp1
42
+ tmp3 = r0_4 + 128*x1
43
+ tmp4 = r0_3 + 128*x0
44
+ tmp5 = tmp3 >= tmp4
45
+ tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0)
46
+ tmp7 = tmp4 < tmp6
47
+ tmp8 = tmp3 < tmp6
48
+ tmp9 = tmp7 & tmp8
49
+ tmp10 = tmp5 & tmp9
50
+ tmp11 = tl.full([1, 1], False, tl.int1)
51
+ tmp12 = tmp11 | tmp10
52
+ tmp13 = tl.full([1, 1], 2048, tl.int64)
53
+ tmp14 = tmp4 >= tmp13
54
+ tmp15 = ((r0_3 + 128*x0) % 2048)
55
+ tmp16 = tmp15 < tmp6
56
+ tmp17 = tmp14 & tmp16
57
+ tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
58
+ tmp19 = (tmp18 % tmp13)
59
+ tmp20 = tl.full([1, 1], 0, tl.int32)
60
+ tmp21 = tmp19 != tmp20
61
+ tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0
62
+ tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0
63
+ tmp24 = tmp22 != tmp23
64
+ tmp25 = tmp21 & tmp24
65
+ tmp26 = tmp19 + tmp13
66
+ tmp27 = tl.where(tmp25, tmp26, tmp19)
67
+ tmp28 = tl.full([1, 1], 0, tl.int64)
68
+ tmp29 = tmp27 == tmp28
69
+ tmp30 = tmp17 & tmp29
70
+ tmp31 = tmp12 | tmp30
71
+ tmp32 = tl.full(tmp31.shape, False, tmp31.dtype)
72
+ tmp33 = tl.where(tmp2, tmp31, tmp32)
73
+ tmp34 = tmp33.to(tl.int64)
74
+ tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK])
75
+ tmp37 = _tmp36 + tmp35
76
+ _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36)
77
+ tmp36 = tl.sum(_tmp36, 1)[:, None]
78
+ tmp38 = tl.full([1, 1], 0, tl.int64)
79
+ tmp39 = tmp36 > tmp38
80
+ tmp40 = tl.full([1, 1], 16384, tl.int64)
81
+ tmp41 = tmp36 < tmp40
82
+ tmp42 = tmp39 & tmp41
83
+ tmp43 = tmp42.to(tl.int8)
84
+ tmp44 = tmp43.to(tl.int32)
85
+ tmp45 = tmp36 == tmp40
86
+ tmp46 = tmp45.to(tl.int8)
87
+ tmp47 = tmp46.to(tl.int32)
88
+ tl.store(out_ptr1 + (x5), tmp44, xmask)
89
+ tl.store(out_ptr2 + (x5), tmp47, xmask)
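The reduction above appears to summarise the attention visibility predicate over 128x128 tiles: it counts how many of the 16384 positions in each tile are visible, then writes one flag for "some but not all visible" (out_ptr1) and one for "fully visible" (out_ptr2), which is how a FlexAttention-style block mask separates tiles that still need the mask_mod from tiles that can skip it. A rough eager illustration with an assumed causal mask follows; the real predicate here is computed inline from the boundaries read out of in_ptr0, not from a materialised mask.

    import torch

    BLOCK = 128
    seq = 2048
    # assumed dense mask purely for illustration
    dense = torch.tril(torch.ones(seq, seq, dtype=torch.bool))

    # count visible positions in every 128x128 tile
    tiles = dense.view(seq // BLOCK, BLOCK, seq // BLOCK, BLOCK)
    counts = tiles.sum(dim=(1, 3))

    partial = ((counts > 0) & (counts < BLOCK * BLOCK)).to(torch.int32)  # cf. out_ptr1
    full = (counts == BLOCK * BLOCK).to(torch.int32)                     # cf. out_ptr2
    print(partial.sum().item(), full.sum().item())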
SpecForge-ext/cache/compiled_kernels/af/cafe3dsuelcloemwu5jdikp7lqano5qxv7iayhtm5xgji2xvr4k6.py ADDED
@@ -0,0 +1,47 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 4096, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 32000
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x0 = (xindex % ks0)
28
+ x1 = xindex // ks0
29
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
30
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
31
+ x3 = xindex
32
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
33
+ r0_index = r0_offset + r0_base
34
+ r0_mask = r0_index < r0_numel
35
+ roffset = r0_offset
36
+ rindex = r0_index
37
+ r0_2 = r0_index
38
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
39
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
40
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
41
+ _tmp2, _tmp2_index, tmp1, rindex
42
+ )
43
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
44
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
45
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
46
+ tmp2 = tmp2_idx[:, None]
47
+ tl.store(out_ptr0 + (x3), tmp2, xmask)
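triton_red_fused_argmax_1 is a looped argmax over a 32000-wide trailing dimension (a vocabulary-sized row of logits), carrying the running maximum and its index through the r0 loop with maximum_with_index. In eager terms, with illustrative shapes:

    import torch

    logits = torch.randn(4, 128, 32000)   # [batch, seq, vocab], shapes illustrative
    token_ids = logits.argmax(dim=-1)     # int64 indices, as written to out_ptr0
    print(token_ids.shape)                # torch.Size([4, 128])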
SpecForge-ext/cache/compiled_kernels/ai/caivmpnbt7ve3qybkm6k756igdxn3ykevul35fdg4vvgknrmprqo.py ADDED
@@ -0,0 +1,66 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 16777216},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x0 = (xindex % ks0)
23
+ x3 = xindex
24
+ x1 = ((xindex // ks0) % ks1)
25
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
26
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
27
+ tmp0 = x0
28
+ tmp1 = ks0 // 2
29
+ tmp2 = tmp0 >= tmp1
30
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
31
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
32
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
33
+ tmp6 = tmp4 + tmp5
34
+ tmp7 = tmp4 < 0
35
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
36
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
37
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
38
+ tmp11 = tmp3 * tmp10
39
+ tmp12 = -tmp11
40
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
41
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
42
+ tmp15 = 0.0
43
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
44
+ tmp17 = tmp0 < tmp1
45
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
46
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
47
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
48
+ tmp21 = tmp19 + tmp20
49
+ tmp22 = tmp19 < 0
50
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
51
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
52
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
53
+ tmp26 = tmp18 * tmp25
54
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
55
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
56
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
57
+ tmp30 = tmp16 + tmp29
58
+ tmp33 = ks3
59
+ tmp34 = tmp32 + tmp33
60
+ tmp35 = tmp32 < 0
61
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
62
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
63
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
64
+ tmp39 = tmp31 * tmp38
65
+ tmp40 = tmp30 + tmp39
66
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
SpecForge-ext/cache/compiled_kernels/ai/f2f38be4dfdf6b1c14c068f88a04203cd9a67c3fc07629f341d6212e60d2f52e.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 53, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"}
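The *.best_config entries record the launch configuration that Inductor's pointwise autotuner settled on for the sibling kernel (block size, warp count, stage count), plus hashes used to validate the cache entry. A small sketch of inspecting one; the filename is a placeholder.

    import json
    from pathlib import Path

    # placeholder path: any *.best_config file under compiled_kernels/ works
    cfg = json.loads(Path("some_kernel.best_config").read_text())
    print(cfg["XBLOCK"], cfg["num_warps"], cfg["num_stages"])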
SpecForge-ext/cache/compiled_kernels/al/25feb68bb70a2d653884ed092be99a324d74e7c4fa2b0800c70b0c5cede23a82.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 50, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"}
SpecForge-ext/cache/compiled_kernels/al/cal2r4tfyw6gic3ggqyud3nufnajx6xau2koieoitx6zg4wsiozm.py ADDED
@@ -0,0 +1,56 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 16777216},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x4 = xindex
23
+ x2 = ((xindex // ks0) % ks1)
24
+ x0 = (xindex % ks3)
25
+ x5 = xindex // ks3
26
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
27
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
28
+ tmp2 = ks2
29
+ tmp3 = tmp1 + tmp2
30
+ tmp4 = tmp1 < 0
31
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
32
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
33
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
34
+ tmp8 = tmp0 * tmp7
35
+ tmp9 = x0
36
+ tmp10 = tl.full([1], 0, tl.int64)
37
+ tmp11 = tmp9 >= tmp10
38
+ tmp12 = ks3 + (-1)*(ks3 // 2)
39
+ tmp13 = tmp9 < tmp12
40
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp15 = -tmp14
42
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
43
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
44
+ tmp18 = tmp9 >= tmp12
45
+ tmp19 = ks3
46
+ tmp20 = tmp9 < tmp19
47
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
48
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
49
+ tmp23 = ks4
50
+ tmp24 = tmp1 + tmp23
51
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
52
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
53
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
54
+ tmp28 = tmp22 * tmp27
55
+ tmp29 = tmp8 + tmp28
56
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
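This pointwise kernel is the fused forward RoPE application: gather the cos/sin rows selected by the position ids, then compute x * cos + rotate_half(x) * sin in a single pass. A hedged eager sketch with illustrative shapes, mirroring the [1, 1, max_pos, head_dim] table layout asserted by the wrapper code in this cache:

    import torch

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope(x, cos_table, sin_table, position_ids):
        # gather per-position rows; broadcasting covers batch and heads
        cos = cos_table[0, 0, position_ids].unsqueeze(1)   # [1, 1, seq, head_dim]
        sin = sin_table[0, 0, position_ids].unsqueeze(1)
        return x * cos + rotate_half(x) * sin

    x = torch.randn(2, 4, 16, 8)                   # [batch, heads, seq, head_dim]
    cos_table = torch.randn(1, 1, 32, 8)
    sin_table = torch.randn(1, 1, 32, 8)
    position_ids = torch.arange(16).unsqueeze(0)   # [1, seq]
    print(apply_rope(x, cos_table, sin_table, position_ids).shape)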
SpecForge-ext/cache/compiled_kernels/aq/caqqpjwqelw7hv6k6nwpxjuod3tfnwg62cypxwyuozfme2ykuybp.py ADDED
@@ -0,0 +1,307 @@
1
+ # AOT ID: ['4_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3k/c3kdupo6eufhy2marzoeoddgc3okqj6m3aii3f42onl4ag77vf6u.py
38
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
39
+ # Source node to ATen node mapping:
40
+ # cat => cat
41
+ # cos => squeeze_1
42
+ # cos_1 => unsqueeze
43
+ # getitem => index
44
+ # getitem_1 => index_1
45
+ # mul => mul_24
46
+ # mul_1 => mul_45
47
+ # neg => neg
48
+ # q_embed => add_54
49
+ # sin => squeeze_3
50
+ # sin_1 => unsqueeze_1
51
+ # squeeze => squeeze
52
+ # squeeze_2 => squeeze_2
53
+ # x1 => slice_1
54
+ # x2 => slice_2
55
+ # Graph fragment:
56
+ # %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:5" = PlaceHolder[target=primals_12]
57
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:5" = PlaceHolder[target=primals_8]
58
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_4]
59
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_6]
60
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
61
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
62
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
63
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
64
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
65
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
66
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
67
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
68
+ # %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {})
69
+ # %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {})
70
+ # %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {})
71
+ # %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {})
72
+ # %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {})
73
+ # %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {})
74
+ # %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {})
75
+ # return %add_54
76
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', '''
77
+ import triton
78
+ import triton.language as tl
79
+
80
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
81
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
82
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
83
+ triton_helpers.set_driver_to_gpu()
84
+
85
+ @triton_heuristics.pointwise(
86
+ size_hints={'x': 67108864},
87
+ filename=__file__,
88
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
89
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
90
+ min_elem_per_thread=0
91
+ )
92
+ @triton.jit
93
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
94
+ xoffset = tl.program_id(0) * XBLOCK
95
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
96
+ xmask = xindex < xnumel
97
+ x4 = xindex
98
+ x2 = ((xindex // ks0) % ks1)
99
+ x0 = (xindex % ks3)
100
+ x5 = xindex // ks3
101
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
102
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
103
+ tmp2 = ks2
104
+ tmp3 = tmp1 + tmp2
105
+ tmp4 = tmp1 < 0
106
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
107
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
108
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
109
+ tmp8 = tmp0 * tmp7
110
+ tmp9 = x0
111
+ tmp10 = tl.full([1], 0, tl.int64)
112
+ tmp11 = tmp9 >= tmp10
113
+ tmp12 = ks3 + (-1)*(ks3 // 2)
114
+ tmp13 = tmp9 < tmp12
115
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
116
+ tmp15 = -tmp14
117
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
118
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
119
+ tmp18 = tmp9 >= tmp12
120
+ tmp19 = ks3
121
+ tmp20 = tmp9 < tmp19
122
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
123
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
124
+ tmp23 = ks4
125
+ tmp24 = tmp1 + tmp23
126
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
127
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
128
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
129
+ tmp28 = tmp22 * tmp27
130
+ tmp29 = tmp8 + tmp28
131
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
132
+ ''', device_str='cuda')
133
+
134
+
135
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/n2/cn24lurjdnbidkarxbtzqpcvotiay3hsbqwsbqw73gg63elg6tak.py
136
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
137
+ # Source node to ATen node mapping:
138
+ # cat_1 => cat_1
139
+ # cos => squeeze_1
140
+ # cos_1 => unsqueeze
141
+ # getitem => index
142
+ # getitem_1 => index_1
143
+ # k_embed => add_90
144
+ # mul_2 => mul_54
145
+ # mul_3 => mul_75
146
+ # neg_1 => neg_1
147
+ # sin => squeeze_3
148
+ # sin_1 => unsqueeze_1
149
+ # squeeze => squeeze
150
+ # squeeze_2 => squeeze_2
151
+ # x1_1 => slice_3
152
+ # x2_1 => slice_4
153
+ # Graph fragment:
154
+ # %primals_13 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:5" = PlaceHolder[target=primals_13]
155
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:5" = PlaceHolder[target=primals_8]
156
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_4]
157
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_6]
158
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
159
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
160
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
161
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
162
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
163
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
164
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
165
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
166
+ # %mul_54 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_13, %unsqueeze), kwargs = {})
167
+ # %slice_3 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24, s24*s48, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, 0, %floordiv), kwargs = {})
168
+ # %slice_4 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24, s24*s48, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, %floordiv, 9223372036854775807), kwargs = {})
169
+ # %neg_1 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s48*Max(1, s24 - ((s24//2))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {})
170
+ # %cat_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {})
171
+ # %mul_75 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {})
172
+ # %add_90 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {})
173
+ # return %add_90
174
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', '''
175
+ import triton
176
+ import triton.language as tl
177
+
178
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
179
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
180
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
181
+ triton_helpers.set_driver_to_gpu()
182
+
183
+ @triton_heuristics.pointwise(
184
+ size_hints={'x': 16777216},
185
+ filename=__file__,
186
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
187
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
188
+ min_elem_per_thread=0
189
+ )
190
+ @triton.jit
191
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
192
+ xoffset = tl.program_id(0) * XBLOCK
193
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
194
+ xmask = xindex < xnumel
195
+ x4 = xindex
196
+ x2 = ((xindex // ks0) % ks1)
197
+ x0 = (xindex % ks3)
198
+ x5 = xindex // ks3
199
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
200
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
201
+ tmp2 = ks2
202
+ tmp3 = tmp1 + tmp2
203
+ tmp4 = tmp1 < 0
204
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
205
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
206
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
207
+ tmp8 = tmp0 * tmp7
208
+ tmp9 = x0
209
+ tmp10 = tl.full([1], 0, tl.int64)
210
+ tmp11 = tmp9 >= tmp10
211
+ tmp12 = ks3 + (-1)*(ks3 // 2)
212
+ tmp13 = tmp9 < tmp12
213
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
214
+ tmp15 = -tmp14
215
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
216
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
217
+ tmp18 = tmp9 >= tmp12
218
+ tmp19 = ks3
219
+ tmp20 = tmp9 < tmp19
220
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
221
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
222
+ tmp23 = ks4
223
+ tmp24 = tmp1 + tmp23
224
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
225
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
226
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
227
+ tmp28 = tmp22 * tmp27
228
+ tmp29 = tmp8 + tmp28
229
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
230
+ ''', device_str='cuda')
231
+
232
+
233
+ async_compile.wait(globals())
234
+ del async_compile
235
+
236
+ class Runner:
237
+ def __init__(self, partitions):
238
+ self.partitions = partitions
239
+
240
+ def recursively_apply_fns(self, fns):
241
+ new_callables = []
242
+ for fn, c in zip(fns, self.partitions):
243
+ new_callables.append(fn(c))
244
+ self.partitions = new_callables
245
+
246
+ def call(self, args):
247
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13 = args
248
+ args.clear()
249
+ s92 = primals_1
250
+ s24 = primals_2
251
+ s96 = primals_3
252
+ s79 = primals_5
253
+ s9 = primals_7
254
+ s38 = primals_9
255
+ s48 = primals_10
256
+ s34 = primals_11
257
+ assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1))
258
+ assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1))
259
+ assert_size_stride(primals_8, (1, s9), (s9, 1))
260
+ assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1))
261
+ assert_size_stride(primals_13, (s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1))
262
+ with torch.cuda._DeviceGuard(5):
263
+ torch.cuda.set_device(5)
264
+ ps0 = s24*s34
265
+ buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16)
266
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
267
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9
268
+ stream5 = get_raw_stream(5)
269
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream5)
270
+ del primals_12
271
+ ps1 = s24*s48
272
+ buf1 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1), torch.bfloat16)
273
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
274
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s9*s48*s48
275
+ stream5 = get_raw_stream(5)
276
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_13, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream5)
277
+ del primals_13
278
+ return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), )
279
+
280
+ runner = Runner(partitions=[])
281
+ call = runner.call
282
+ recursively_apply_fns = runner.recursively_apply_fns
283
+
284
+
285
+ def benchmark_compiled_module(times=10, repeat=10):
286
+ from torch._dynamo.testing import rand_strided
287
+ from torch._inductor.utils import print_performance
288
+ primals_1 = 2048
289
+ primals_2 = 128
290
+ primals_3 = 5245440
291
+ primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16)
292
+ primals_5 = 2048
293
+ primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16)
294
+ primals_7 = 2048
295
+ primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:5', dtype=torch.int64)
296
+ primals_9 = 1
297
+ primals_10 = 8
298
+ primals_11 = 32
299
+ primals_12 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16)
300
+ primals_13 = rand_strided((8, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16)
301
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13])
302
+ return print_performance(fn, times=times, repeat=repeat)
303
+
304
+
305
+ if __name__ == "__main__":
306
+ from torch._inductor.wrapper_benchmark import compiled_module_main
307
+ compiled_module_main('None', benchmark_compiled_module)
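The stride tuples asserted for primals_12 and primals_13 above, e.g. (s24*s34*s9, s24, s24*s34, 1), are what a [batch, seq, heads, head_dim]-contiguous tensor exposes after transpose(1, 2) into the [batch, heads, seq, head_dim] attention layout, which is why the benchmark builds its inputs with rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1)). A quick check of that reading; the meta device is used so no memory is allocated.

    import torch

    b, h, s, d = 8, 32, 2048, 128
    x = torch.empty(b, s, h, d, device="meta").transpose(1, 2)   # view as [b, h, s, d]
    print(x.shape, x.stride())   # torch.Size([8, 32, 2048, 128]) (8388608, 128, 4096, 1)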
SpecForge-ext/cache/compiled_kernels/aq/caqvrlb25w5an4txp3dstxcj6tqlcc4mprakf75e5sbtbuzd254g.py ADDED
@@ -0,0 +1,711 @@
1
+ # AOT ID: ['13_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
17
+ import triton
18
+ import triton.language as tl
19
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
20
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
21
+
22
+ aten = torch.ops.aten
23
+ inductor_ops = torch.ops.inductor
24
+ _quantized = torch.ops._quantized
25
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
26
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
27
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
28
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
29
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
30
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
31
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
32
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
33
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
34
+ async_compile = AsyncCompile()
35
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
36
+
37
+
38
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/u4/cu4la2snj6taof6hjdgfl2ludclb5rxnhhncr47hr5tawo3djlhk.py
39
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
40
+ # Source node to ATen node mapping:
41
+ # flex_attention => flex_attention
42
+ # Graph fragment:
43
+ # %primals_2 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_2]
44
+ # %primals_4 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_4]
45
+ # %primals_6 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_6]
46
+ # %getitem_1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=getitem_1]
47
+ # %buf1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1]
48
+ # %primals_13 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:7" = PlaceHolder[target=primals_13]
49
+ # %primals_9 : Tensor "i32[2, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:7" = PlaceHolder[target=primals_9]
50
+ # %primals_17 : Tensor "i32[2, 1, s94][s94, s94, 1]cuda:7" = PlaceHolder[target=primals_17]
51
+ # %primals_20 : Tensor "i32[2, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:7" = PlaceHolder[target=primals_20]
52
+ # %primals_14 : Tensor "i64[2][1]cuda:7" = PlaceHolder[target=primals_14]
53
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {})
54
+ # return %getitem
55
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
56
+ import triton
57
+ import triton.language as tl
58
+
59
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
60
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
61
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
62
+
63
+ @triton_heuristics.template(
64
+
65
+ num_stages=3,
66
+ num_warps=8,
67
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
68
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
69
+
70
+ )
71
+ @triton.jit
72
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5):
73
+ PRESCALE_QK : tl.constexpr = False
74
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
75
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
76
+ WRITE_DQ : tl.constexpr = True
77
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
78
+ OUTPUT_MAX : tl.constexpr = False
79
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
80
+ IS_DIVISIBLE : tl.constexpr = False
81
+ SM_SCALE : tl.constexpr = 0.08838834764831843
82
+ GQA_SHARED_HEADS : tl.constexpr = 4
83
+ HAS_FULL_BLOCKS : tl.constexpr = True
84
+ QK_HEAD_DIM : tl.constexpr = 128
85
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ V_HEAD_DIM : tl.constexpr = 128
87
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
88
+ SAFE_HEAD_DIM : tl.constexpr = True
89
+ USE_TMA : tl.constexpr = False
90
+ BLOCK_M : tl.constexpr = 128
91
+ BLOCK_N : tl.constexpr = 64
92
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
93
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
94
+ INDEX_DTYPE : tl.constexpr = tl.int32
95
+ Q = arg_Q
96
+ K = arg_K
97
+ V = arg_V
98
+ LSE = arg_LSE
99
+ MAX = arg_MAX
100
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
101
+ KV_IDX = arg_KV_IDX
102
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
103
+ FULL_KV_IDX = arg_FULL_KV_IDX
104
+
105
+ # Sub notation for this kernel:
106
+ #
107
+ # Q: Query, K: Key, V: Value
108
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
109
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
110
+ # V_HEAD_DIM: The dimension of the value embeddings
111
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
112
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
113
+ #
114
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
115
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
116
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
117
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
118
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
119
+ #
120
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
121
+ #
122
+ # (Modifiable) Performance tuning options
123
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
124
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
125
+
126
+ # The below are kernel options that can be applied for certain score_mods,
127
+ # or involve a numerics vs. perf tradeoff
128
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
129
+ # about 20% more numerical error, but slightly faster.
130
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
131
+ # is not masked out? If so, we can skip an extra safety check
132
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
133
+ # contiguous? If so, we don't need to do an indirect jump for every block
134
+
135
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
136
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
137
+
138
+ # Define strides of inputs
139
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
140
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1
141
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1
142
+
143
+ ZQ = 2
144
+ HQ = 32
145
+ Q_LEN = ks0
146
+ ZKV = 2
147
+ KV_LEN = ks1
148
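+ # The strides above and Q_LEN/KV_LEN are expressed in terms of ks0/ks1 because the
+ # sequence lengths are dynamic in this cached variant: Q_LEN = ks0 and KV_LEN = ks1.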
+
149
+ MATMUL_PRECISION = Q.dtype.element_ty
150
+
151
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
152
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
153
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
154
+
155
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
156
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
157
+ off_zkv = off_zq % ZKV
158
+ off_hkv = off_hq // GQA_SHARED_HEADS
159
+ off_g = off_hq % GQA_SHARED_HEADS
160
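+ # GQA head mapping: with HQ = 32 and GQA_SHARED_HEADS = 4 there are 8 KV heads, e.g.
+ # query head off_hq = 13 uses KV head off_hkv = 13 // 4 = 3 and group slot off_g = 13 % 4 = 1.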
+
161
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
162
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
163
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
164
+
165
+ Q = Q + q_offset
166
+ K = K + k_offset
167
+ V = V + v_offset
168
+
169
+ # Setting up the TMA descriptors for Q, K, V
170
+ desc_q = None
171
+ desc_k = None
172
+ desc_v = None
173
+
174
+ SPARSE_Z = 2
175
+ SPARSE_HQ = 1
176
+
177
+ sparse_idx_z = off_zq % SPARSE_Z
178
+ sparse_idx_hq = off_hq % SPARSE_HQ
179
+
180
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
181
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
182
+
183
+ stride_kv_num_blks_h = ks2
184
+ stride_kv_idx_h = ks3*ks4
185
+ stride_kv_idx_m = ks4
186
+
187
+ # initialize pointer to m and l
188
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
189
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
190
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
191
+
192
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
193
+
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
196
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
197
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
198
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
199
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
200
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
201
+
202
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
203
+ # We don't know anything "special" about these blocks, so we need to apply
204
+ # both score_mod and mask_mod to it
205
+ kv_indices = KV_IDX + sparse_kv_idx_offset
206
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
207
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
208
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
209
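+ # Each SPARSE_KV_BLOCK_SIZE = 128 block covers SPARSE_KV_MULTIPLE = 2 tiles of BLOCK_N = 64,
+ # so the loop visits kv_num_blocks * 2 tiles, clamped to the number of tiles that fit in KV_LEN.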
+
210
+
211
+ # K and V pointers will be passed directly to forward_inner
212
+
213
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
214
+
215
+
216
+ acc, l_i, m_i = forward_inner(
217
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
218
+ q, K, V,
219
+ desc_k, desc_v, Q_LEN, KV_LEN,
220
+ acc, l_i, m_i,
221
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
222
+ kv_start,
223
+ kv_indices, kv_num_blocks,
224
+ 0, block_n_end,
225
+ MATMUL_PRECISION,
226
+ stride_kk, stride_kn, stride_vn, stride_vk,
227
+ IS_FULL_BLOCKS=False,
228
+ )
229
+
230
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
231
+ # We know these blocks are guaranteed to be "full", so we don't need to
232
+ # apply mask_mod to them - only score_mod
233
+ if HAS_FULL_BLOCKS:
234
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
235
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
236
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
237
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
238
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
239
+ # K and V pointers will be passed directly to forward_inner
240
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
241
+
242
+ acc, l_i, m_i = forward_inner(
243
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
244
+ q, K, V,
245
+ desc_k, desc_v, Q_LEN, KV_LEN,
246
+ acc, l_i, m_i,
247
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
248
+ kv_start,
249
+ kv_indices, kv_num_blocks,
250
+ 0, block_n_end,
251
+ MATMUL_PRECISION,
252
+ stride_kk, stride_kn, stride_vn, stride_vk,
253
+ IS_FULL_BLOCKS=True,
254
+ )
255
+
256
+
257
+ # [Note] Handle fully masked out rows:
258
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
259
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
260
+ l_i = tl.where(l_i == 0.0, 1, l_i)
261
+
262
+ acc = acc / l_i[:, None]
263
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
264
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
265
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
266
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
267
+
268
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
269
+
270
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
271
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
272
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask)
273
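+ # The store offset idx_d + 128*idx_hq + 4096*idx_m + 4096*ks0*idx_zq matches the
+ # (2, 32, ks0, 128) output buffer with strides (4096*ks0, 128, 4096, 1) allocated in call();
+ # the xindex computed above is not used by this store.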
+
274
+ if OUTPUT_LOGSUMEXP:
275
+ off_hz = off_zq * HQ + off_hq
276
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
277
+ lse = m_i + tl.math.log2(l_i)
278
+ if IS_DIVISIBLE:
279
+ tl.store(l_ptrs, lse)
280
+ else:
281
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
282
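+ # Scores were scaled by RCP_LN2 and accumulated with exp2, so m_i + log2(l_i) is the
+ # per-row logsumexp of the sm_scale-scaled scores expressed in base 2.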
+
283
+ if OUTPUT_MAX:
284
+ off_hz = off_zq * HQ + off_hq
285
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
286
+ if IS_DIVISIBLE:
287
+ tl.store(max_ptrs, m_i)
288
+ else:
289
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
290
+
291
+
292
+ # Utility triton funcs
293
+ @triton.jit
294
+ def get_offset_for_next_block(
295
+ loop_iter, col_indices, total_blocks,
296
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
297
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
298
+ ):
299
+ if BLOCKS_ARE_CONTIGUOUS:
300
+ return BLOCK
301
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
302
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
303
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
304
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
305
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
306
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
307
+ return offset
308
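+ # Example with SPARSE_BLOCK = 128, SPARSE_BLOCK_MULTIPLE = 2, BLOCK = 64: inside a sparse
+ # block the offset is simply 64; after its second tile, if the index list jumps from block 3
+ # to block 7, the offset is (7 - 3)*128 - 64 = 448, landing the pointer at the start of block 7.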
+
309
+ @triton.jit
310
+ def get_bounded_indices(indices, max_len=None):
311
+ return indices % max_len if max_len is not None else indices
312
+
313
+ @triton.jit
314
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
315
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr)
317
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
319
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
320
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
321
+ else:
322
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
323
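+ # load_checked_block operates on block pointers (boundary_check/padding_option); this
+ # instantiation only uses the offset-based load_checked_2d below, so it appears unused here.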
+
324
+ @triton.jit
325
+ def load_checked_2d(
326
+ ptr,
327
+ offs_m,
328
+ offs_n,
329
+ stride_m,
330
+ stride_n,
331
+ IS_DIVISIBLE_M: tl.constexpr,
332
+ IS_DIVISIBLE_N: tl.constexpr,
333
+ M_LEN: tl.constexpr,
334
+ N_LEN: tl.constexpr,
335
+ ):
336
+ # Calculate final pointer if strides are provided
337
+ if stride_m is not None and stride_n is not None:
338
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
339
+
340
+ # Handle all masking cases
341
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
343
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
345
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
346
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
347
+ else: # Both divisible
348
+ return tl.load(ptr)
349
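+ # IS_DIVISIBLE_M / IS_DIVISIBLE_N are constexpr, so only one of the four load paths survives
+ # specialization; the q/k/v loads in this kernel pass IS_DIVISIBLE=False and SAFE_HEAD_DIM=True,
+ # i.e. rows are masked against the sequence length while the head dim is loaded unmasked.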
+
350
+
351
+ # Common Imports
352
+ @triton.jit
353
+ def forward_block_mn(
354
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
355
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
356
+ # accumulated values
357
+ acc, l_i, m_i,
358
+ # Offsets
359
+ off_z, off_h, offs_m, offs_n,
360
+ # Offsets needed for TMA loads
361
+ kv_start,
362
+ kv_offset,
363
+ MATMUL_PRECISION, RCP_LN2,
364
+ # Strides for K and V
365
+ stride_kk, stride_kn, stride_vn, stride_vk,
366
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
367
+
368
+ ):
369
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
370
+ PRESCALE_QK : tl.constexpr = False
371
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
372
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
373
+ WRITE_DQ : tl.constexpr = True
374
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
375
+ OUTPUT_MAX : tl.constexpr = False
376
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
377
+ IS_DIVISIBLE : tl.constexpr = False
378
+ SM_SCALE : tl.constexpr = 0.08838834764831843
379
+ GQA_SHARED_HEADS : tl.constexpr = 4
380
+ HAS_FULL_BLOCKS : tl.constexpr = True
381
+ QK_HEAD_DIM : tl.constexpr = 128
382
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ V_HEAD_DIM : tl.constexpr = 128
384
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
385
+ SAFE_HEAD_DIM : tl.constexpr = True
386
+ USE_TMA : tl.constexpr = False
387
+ BLOCK_M : tl.constexpr = 128
388
+ BLOCK_N : tl.constexpr = 64
389
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
390
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
391
+ INDEX_DTYPE : tl.constexpr = tl.int32
392
+
393
+
394
+ # -- load k --
395
+ # NB: reversed order since K is transposed
396
+ kv_base_offset = kv_start + kv_offset
397
+
398
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
399
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
400
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
401
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
402
+
403
+ k = tl.trans(k)
404
+ # -- compute qk ---
405
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
406
+ if not PRESCALE_QK:
407
+ qk *= SM_SCALE
408
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
409
+ # If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
410
+ # which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
411
+ # we need to mask out the elements that fall outside Q_LEN & KV_LEN.
412
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
413
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
414
+
415
+ tmp0 = (qk)
416
+ post_mod_scores = tmp0
417
+
418
+
419
+ if CHECK_BLOCK_BOUNDARY:
420
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
421
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
422
+
423
+ if not IS_FULL_BLOCKS:
424
+ tmp1 = tl.full([1], False, tl.int1)
425
+ tmp2 = (m)
426
+ tmp3 = (n)
427
+ tmp4 = tmp2 >= tmp3
428
+ tmp5 = tmp3.to(tl.int64)
429
+ tmp6 = (off_z)
430
+ tmp7 = tl.load(in_ptr9 + tmp6)
431
+ tmp8 = tmp5 < tmp7
432
+ tmp9 = tmp2.to(tl.int64)
433
+ tmp10 = tmp9 < tmp7
434
+ tmp11 = tmp8 & tmp10
435
+ tmp12 = tmp4 & tmp11
436
+ tmp13 = tmp1 | tmp12
437
+ tmp14 = ks5
438
+ tmp15 = tmp3 >= tmp14
439
+ tmp16 = (tmp3 % tmp14)
440
+ tmp17 = tl.full([1], 0, tl.int32)
441
+ tmp18 = tmp16 != tmp17
442
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
443
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
444
+ tmp21 = tmp19 != tmp20
445
+ tmp22 = tmp18 & tmp21
446
+ tmp23 = tmp16 + tmp14
447
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
448
+ tmp25 = tmp24.to(tl.int64)
449
+ tmp26 = tmp25 < tmp7
450
+ tmp27 = tmp15 & tmp26
451
+ tmp28 = tmp3 - tmp2
452
+ tmp29 = (tmp28 % tmp14)
453
+ tmp30 = tmp29 != tmp17
454
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
455
+ tmp32 = tmp31 != tmp20
456
+ tmp33 = tmp30 & tmp32
457
+ tmp34 = tmp29 + tmp14
458
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
459
+ tmp36 = tmp35 == tmp17
460
+ tmp37 = tmp27 & tmp36
461
+ tmp38 = tmp13 | tmp37
462
+ mask_mod_output = tmp38
463
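+ # Inlined mask_mod, as generated: keep a position if it is causal (m >= n) with both m and n
+ # inside this batch's valid length in_ptr9[off_z], or if n lies past the ks5 boundary, wraps
+ # (n % ks5) back inside the valid length, and is offset from m by a multiple of ks5.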
+
464
+
465
+ if CHECK_BLOCK_BOUNDARY:
466
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
467
+ # apply mask for partially unmasked blocks
468
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
469
+
470
+ if not PRESCALE_QK:
471
+ post_mod_scores *= RCP_LN2
472
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
473
+
474
+ # -- compute scaling constant ---
475
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
476
+ if not ROWS_GUARANTEED_SAFE:
477
+ masked_out_rows = (m_ij == float("-inf"))
478
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
479
+ else:
480
+ m_ij_masked = m_ij
481
+
482
+ alpha = tl.math.exp2(m_i - m_ij_masked)
483
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
484
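+ # Standard online-softmax update: m_ij is the new running max, alpha rescales the previously
+ # accumulated numerator (acc) and denominator (l_i), and p holds the current block's
+ # exp2-normalized scores before they are folded into both.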
+
485
+ # NB: l_i update is pulled up here since it's a bit faster
486
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
487
+ # m_ij
488
+ l_i = l_i * alpha + tl.sum(p, 1)
489
+ # # -- scale and update acc --
490
+ acc = acc * alpha[:, None]
491
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
492
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
493
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
494
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
495
+
496
+ # -- update m_i
497
+ m_i = m_ij
498
+
499
+ return acc, l_i, m_i
500
+
501
+ @triton.jit
502
+ def forward_inner(
503
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
504
+ q, K, V,
505
+ desc_k, desc_v, Q_LEN, KV_LEN,
506
+ # accumulated values
507
+ acc, l_i, m_i,
508
+ # Offsets used as inputs to score_mod & mask_mod
509
+ # of size [BLOCK_M, BLOCK_N] or scalar.
510
+ off_z, off_h, offs_m, offs_n,
511
+ # Offsets needed for TMA loads
512
+ kv_start,
513
+ # blocksparse data
514
+ kv_indices, kv_num_blocks,
515
+ # start kv and end kv block
516
+ block_n_start, block_n_end,
517
+ MATMUL_PRECISION,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS,
521
+ ):
522
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
523
+ PRESCALE_QK : tl.constexpr = False
524
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
525
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
526
+ WRITE_DQ : tl.constexpr = True
527
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
528
+ OUTPUT_MAX : tl.constexpr = False
529
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
530
+ IS_DIVISIBLE : tl.constexpr = False
531
+ SM_SCALE : tl.constexpr = 0.08838834764831843
532
+ GQA_SHARED_HEADS : tl.constexpr = 4
533
+ HAS_FULL_BLOCKS : tl.constexpr = True
534
+ QK_HEAD_DIM : tl.constexpr = 128
535
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
536
+ V_HEAD_DIM : tl.constexpr = 128
537
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
538
+ SAFE_HEAD_DIM : tl.constexpr = True
539
+ USE_TMA : tl.constexpr = False
540
+ BLOCK_M : tl.constexpr = 128
541
+ BLOCK_N : tl.constexpr = 64
542
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
543
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
544
+ INDEX_DTYPE : tl.constexpr = tl.int32
545
+
546
+
547
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
548
+ RCP_LN2: tl.constexpr = 1.44269504
549
+
550
+ if PRESCALE_QK:
551
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
552
+
553
+ kv_offset = 0
554
+
555
+ # loop over k, v and update accumulator until block_n_end
556
+ for start_n in range(block_n_start, block_n_end):
557
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
558
+ if IS_DIVISIBLE:
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS,
573
+ )
574
+ else:
575
+ # Benchmarks show that even when we apply mod & mask to every block for non-divisible seqlens,
576
+ # it's on par with or slightly faster than applying them only to the last block in fwd.
577
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
578
+ # to the last block because it's a lot faster.
579
+ acc, l_i, m_i = forward_block_mn(
580
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
581
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
582
+ # accumulated values
583
+ acc, l_i, m_i,
584
+ # Offsets
585
+ off_z, off_h, offs_m, offs_n,
586
+ # Offsets needed for TMA loads
587
+ kv_start,
588
+ kv_offset,
589
+ MATMUL_PRECISION, RCP_LN2,
590
+ # Strides for K and V
591
+ stride_kk, stride_kn, stride_vn, stride_vk,
592
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
593
+ )
594
+
595
+
596
+
597
+ offset = get_offset_for_next_block(
598
+ start_n, kv_indices, kv_num_blocks,
599
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
600
+ )
601
+
602
+ offs_n = offs_n + offset
603
+ kv_offset += offset
604
+
605
+
606
+ return acc, l_i, m_i
607
+ ''', device_str='cuda')
608
+
609
+
610
+ async_compile.wait(globals())
611
+ del async_compile
612
+
613
+ class Runner:
614
+ def __init__(self, partitions):
615
+ self.partitions = partitions
616
+
617
+ def recursively_apply_fns(self, fns):
618
+ new_callables = []
619
+ for fn, c in zip(fns, self.partitions):
620
+ new_callables.append(fn(c))
621
+ self.partitions = new_callables
622
+
623
+ def call(self, args):
624
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30 = args
625
+ args.clear()
626
+ s50 = primals_1
627
+ s0 = primals_3
628
+ s43 = primals_5
629
+ s22 = primals_7
630
+ s72 = primals_8
631
+ s37 = primals_10
632
+ s71 = primals_11
633
+ s99 = primals_12
634
+ s75 = primals_15
635
+ s94 = primals_16
636
+ s28 = primals_18
637
+ s4 = primals_19
638
+ s56 = primals_21
639
+ s84 = primals_23
640
+ s53 = primals_24
641
+ s100 = primals_26
642
+ s6 = primals_28
643
+ s10 = primals_29
644
+ assert_size_stride(primals_2, (2, 32, s37, 128), (4096*s37, 128, 4096, 1))
645
+ assert_size_stride(primals_4, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
646
+ assert_size_stride(primals_6, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
647
+ assert_size_stride(primals_9, (2, 1, s22, s72), (s22*s72, s22*s72, s72, 1))
648
+ assert_size_stride(primals_13, (2, 1, s99), (s99, s99, 1))
649
+ assert_size_stride(primals_14, (2, ), (1, ))
650
+ assert_size_stride(primals_17, (2, 1, s94), (s94, s94, 1))
651
+ assert_size_stride(primals_20, (2, 1, s28, s4), (s28*s4, s28*s4, s4, 1))
652
+ assert_size_stride(primals_22, (2, 1, s56), (s56, s56, 1))
653
+ assert_size_stride(primals_25, (2, 1, s84, s53), (s53*s84, s53*s84, s53, 1))
654
+ assert_size_stride(primals_27, (2, 1, s100), (s100, s100, 1))
655
+ assert_size_stride(primals_30, (2, 1, s6, s10), (s10*s6, s10*s6, s10, 1))
656
+ with torch.cuda._DeviceGuard(7):
657
+ torch.cuda.set_device(7)
658
+ buf0 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32)
659
+ buf1 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32)
660
+ buf2 = empty_strided_cuda((2, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
661
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
662
+ stream7 = get_raw_stream(7)
663
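+ # FixedGrid launch: _grid_0 = (127 + s37) // 128 query blocks of BLOCK_M = 128,
+ # _grid_1 = 2 (batch), _grid_2 = 32 (query heads), matching program_id(0/1/2) in the kernel.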
+ triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_17, primals_20, primals_14, buf2, s37, s0, s99, s22, s72, s75, (127 + s37) // 128, 2, 32, stream=stream7)
664
+ del buf1
665
+ return (buf2, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, buf2, buf0, s37, s0, s75, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s6, )
666
+
667
+ runner = Runner(partitions=[])
668
+ call = runner.call
669
+ recursively_apply_fns = runner.recursively_apply_fns
670
+
671
+
672
+ def benchmark_compiled_module(times=10, repeat=10):
673
+ from torch._dynamo.testing import rand_strided
674
+ from torch._inductor.utils import print_performance
675
+ primals_1 = 1904
676
+ primals_2 = rand_strided((2, 32, 1904, 128), (7798784, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16)
677
+ primals_3 = 1904
678
+ primals_4 = rand_strided((2, 8, 1904, 128), (1949696, 243712, 128, 1), device='cuda:7', dtype=torch.bfloat16)
679
+ primals_5 = 1904
680
+ primals_6 = rand_strided((2, 8, 1904, 128), (1949696, 243712, 128, 1), device='cuda:7', dtype=torch.bfloat16)
681
+ primals_7 = 15
682
+ primals_8 = 15
683
+ primals_9 = rand_strided((2, 1, 15, 15), (225, 225, 15, 1), device='cuda:7', dtype=torch.int32)
684
+ primals_10 = 1904
685
+ primals_11 = 1904
686
+ primals_12 = 15
687
+ primals_13 = rand_strided((2, 1, 15), (15, 15, 1), device='cuda:7', dtype=torch.int32)
688
+ primals_14 = rand_strided((2, ), (1, ), device='cuda:7', dtype=torch.int64)
689
+ primals_15 = 1904
690
+ primals_16 = 15
691
+ primals_17 = rand_strided((2, 1, 15), (15, 15, 1), device='cuda:7', dtype=torch.int32)
692
+ primals_18 = 15
693
+ primals_19 = 15
694
+ primals_20 = rand_strided((2, 1, 15, 15), (225, 225, 15, 1), device='cuda:7', dtype=torch.int32)
695
+ primals_21 = 15
696
+ primals_22 = rand_strided((2, 1, 15), (15, 15, 1), device='cuda:7', dtype=torch.int32)
697
+ primals_23 = 15
698
+ primals_24 = 15
699
+ primals_25 = rand_strided((2, 1, 15, 15), (225, 225, 15, 1), device='cuda:7', dtype=torch.int32)
700
+ primals_26 = 15
701
+ primals_27 = rand_strided((2, 1, 15), (15, 15, 1), device='cuda:7', dtype=torch.int32)
702
+ primals_28 = 15
703
+ primals_29 = 15
704
+ primals_30 = rand_strided((2, 1, 15, 15), (225, 225, 15, 1), device='cuda:7', dtype=torch.int32)
705
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30])
706
+ return print_performance(fn, times=times, repeat=repeat)
707
+
708
+
709
+ if __name__ == "__main__":
710
+ from torch._inductor.wrapper_benchmark import compiled_module_main
711
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/at/2dfb5ffb77d217b8298333b84d6362971879c20614915aac57601c1f150ac07b.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "IK5RT3JGLTF5PMMUH32NIWB2GXNU6R6CGIZSCRHU3I65YM226KDA"}
SpecForge-ext/cache/compiled_kernels/at/cat6f3b7vbc3opxxrqwtgyrnap7msqfa5gw45bly56fm7xfzsng7.py ADDED
@@ -0,0 +1,27 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 2048},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x0 = (xindex % ks0)
23
+ x1 = ((xindex // ks0) % ks1)
24
+ x2 = xindex // ks2
25
+ tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last')
26
+ tmp1 = tmp0.to(tl.int32)
27
+ tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask)
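+ # The repeated ((1) * ((1) >= (ksX)) + (ksX) * ((ksX) > (1))) terms are Inductor's max(1, ksX)
+ # guard on dynamic strides; the kernel appears to just cast an i64 index tensor to i32 while
+ # swapping the roles of the last two logical dimensions.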
SpecForge-ext/cache/compiled_kernels/at/catnwworbo47zz5uux2qx6gtvq5zrkdmzm5qpt64msmr3cjlnoz5.py ADDED
@@ -0,0 +1,675 @@
1
+ # AOT ID: ['6_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
17
+ import triton
18
+ import triton.language as tl
19
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
20
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
21
+
22
+ aten = torch.ops.aten
23
+ inductor_ops = torch.ops.inductor
24
+ _quantized = torch.ops._quantized
25
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
26
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
27
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
28
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
29
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
30
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
31
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
32
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
33
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
34
+ async_compile = AsyncCompile()
35
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
36
+
37
+
38
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/en/cenh5uz42ng4lj7xw7veh7qtahkm73nfwpjlgreomiruz4qp4l5j.py
39
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
40
+ # Source node to ATen node mapping:
41
+ # flex_attention => flex_attention
42
+ # Graph fragment:
43
+ # %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_1]
44
+ # %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:0" = PlaceHolder[target=primals_2]
45
+ # %primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:0" = PlaceHolder[target=primals_3]
46
+ # %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:0" = PlaceHolder[target=getitem_1]
47
+ # %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:0" = PlaceHolder[target=buf1]
48
+ # %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=primals_5]
49
+ # %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0" = PlaceHolder[target=primals_4]
50
+ # %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=primals_7]
51
+ # %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:0" = PlaceHolder[target=primals_8]
52
+ # %primals_6 : Tensor "i64[8][1]cuda:0" = PlaceHolder[target=primals_6]
53
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {})
54
+ # return %getitem
55
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
56
+ import triton
57
+ import triton.language as tl
58
+
59
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
60
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
61
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
62
+
63
+ @triton_heuristics.template(
64
+
65
+ num_stages=3,
66
+ num_warps=8,
67
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
68
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
69
+
70
+ )
71
+ @triton.jit
72
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0):
73
+ PRESCALE_QK : tl.constexpr = False
74
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
75
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
76
+ WRITE_DQ : tl.constexpr = True
77
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
78
+ OUTPUT_MAX : tl.constexpr = False
79
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
80
+ IS_DIVISIBLE : tl.constexpr = True
81
+ SM_SCALE : tl.constexpr = 0.08838834764831843
82
+ GQA_SHARED_HEADS : tl.constexpr = 4
83
+ HAS_FULL_BLOCKS : tl.constexpr = True
84
+ QK_HEAD_DIM : tl.constexpr = 128
85
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ V_HEAD_DIM : tl.constexpr = 128
87
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
88
+ SAFE_HEAD_DIM : tl.constexpr = True
89
+ USE_TMA : tl.constexpr = False
90
+ BLOCK_M : tl.constexpr = 128
91
+ BLOCK_N : tl.constexpr = 64
92
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
93
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
94
+ INDEX_DTYPE : tl.constexpr = tl.int32
95
+ Q = arg_Q
96
+ K = arg_K
97
+ V = arg_V
98
+ LSE = arg_LSE
99
+ MAX = arg_MAX
100
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
101
+ KV_IDX = arg_KV_IDX
102
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
103
+ FULL_KV_IDX = arg_FULL_KV_IDX
104
+
105
+ # Sub notation for this kernel:
106
+ #
107
+ # Q: Query, K: Key, V: Value
108
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
109
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
110
+ # V_HEAD_DIM: The dimension of the value embeddings
111
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
112
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
113
+ #
114
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
115
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
116
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
117
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
118
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
119
+ #
120
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
121
+ #
122
+ # (Modifiable) Performance tuning options
123
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
124
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
125
+
126
+ # The below are kernel options that can be applied for certain score_mods,
127
+ # or involve a numerics vs. perf tradeoff
128
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
129
+ # about 20% more numerical error, but slightly faster.
130
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
131
+ # is not masked out? If so, we can skip an extra safety check
132
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
133
+ # contiguous? If so, we don't need to do an indirect jump for every block
134
+
135
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
136
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
137
+
138
+ # Define strides of inputs
139
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
140
+ stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1
141
+ stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1
142
+
143
+ ZQ = 8
144
+ HQ = 32
145
+ Q_LEN = 2048
146
+ ZKV = 8
147
+ KV_LEN = 2048
148
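+ # Unlike the dynamic-shape variant earlier in this cache, this kernel is specialized to static
+ # shapes (ZQ = 8, HQ = 32, Q_LEN = KV_LEN = 2048), so the strides above are constants and
+ # IS_DIVISIBLE is True, removing the row masking on loads and stores.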
+
149
+ MATMUL_PRECISION = Q.dtype.element_ty
150
+
151
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
152
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
153
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
154
+
155
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
156
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
157
+ off_zkv = off_zq % ZKV
158
+ off_hkv = off_hq // GQA_SHARED_HEADS
159
+ off_g = off_hq % GQA_SHARED_HEADS
160
+
161
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
162
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
163
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
164
+
165
+ Q = Q + q_offset
166
+ K = K + k_offset
167
+ V = V + v_offset
168
+
169
+ # Setting up the TMA descriptors for Q, K, V
170
+ desc_q = None
171
+ desc_k = None
172
+ desc_v = None
173
+
174
+ SPARSE_Z = 8
175
+ SPARSE_HQ = 1
176
+
177
+ sparse_idx_z = off_zq % SPARSE_Z
178
+ sparse_idx_hq = off_hq % SPARSE_HQ
179
+
180
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
181
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
182
+
183
+ stride_kv_num_blks_h = 16
184
+ stride_kv_idx_h = 256
185
+ stride_kv_idx_m = 16
186
+
187
+ # initialize pointer to m and l
188
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
189
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
190
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
191
+
192
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
193
+
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
196
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
197
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
198
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
199
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
200
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
201
+
202
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
203
+ # We don't know anything "special" about these blocks, so we need to apply
204
+ # both score_mod and mask_mod to it
205
+ kv_indices = KV_IDX + sparse_kv_idx_offset
206
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
207
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
208
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
209
+
210
+
211
+ # K and V pointers will be passed directly to forward_inner
212
+
213
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
214
+
215
+
216
+ acc, l_i, m_i = forward_inner(
217
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
218
+ q, K, V,
219
+ desc_k, desc_v, Q_LEN, KV_LEN,
220
+ acc, l_i, m_i,
221
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
222
+ kv_start,
223
+ kv_indices, kv_num_blocks,
224
+ 0, block_n_end,
225
+ MATMUL_PRECISION,
226
+ stride_kk, stride_kn, stride_vn, stride_vk,
227
+ IS_FULL_BLOCKS=False,
228
+ )
229
+
230
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
231
+ # We know these blocks are guaranteed to be "full", so we don't need to
232
+ # apply mask_mod to them - only score_mod
233
+ if HAS_FULL_BLOCKS:
234
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
235
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
236
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
237
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
238
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
239
+ # K and V pointers will be passed directly to forward_inner
240
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
241
+
242
+ acc, l_i, m_i = forward_inner(
243
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
244
+ q, K, V,
245
+ desc_k, desc_v, Q_LEN, KV_LEN,
246
+ acc, l_i, m_i,
247
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
248
+ kv_start,
249
+ kv_indices, kv_num_blocks,
250
+ 0, block_n_end,
251
+ MATMUL_PRECISION,
252
+ stride_kk, stride_kn, stride_vn, stride_vk,
253
+ IS_FULL_BLOCKS=True,
254
+ )
255
+
256
+
257
+ # [Note] Handle fully masked out rows:
258
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
259
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
260
+ l_i = tl.where(l_i == 0.0, 1, l_i)
261
+
262
+ acc = acc / l_i[:, None]
263
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
264
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
265
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
266
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
267
+
268
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
269
+
270
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
271
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
272
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
273
+
274
+ if OUTPUT_LOGSUMEXP:
275
+ off_hz = off_zq * HQ + off_hq
276
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
277
+ lse = m_i + tl.math.log2(l_i)
278
+ if IS_DIVISIBLE:
279
+ tl.store(l_ptrs, lse)
280
+ else:
281
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
282
+
283
+ if OUTPUT_MAX:
284
+ off_hz = off_zq * HQ + off_hq
285
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
286
+ if IS_DIVISIBLE:
287
+ tl.store(max_ptrs, m_i)
288
+ else:
289
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
290
+
291
+
292
+ # Utility triton funcs
293
+ @triton.jit
294
+ def get_offset_for_next_block(
295
+ loop_iter, col_indices, total_blocks,
296
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
297
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
298
+ ):
299
+ if BLOCKS_ARE_CONTIGUOUS:
300
+ return BLOCK
301
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
302
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
303
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
304
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
305
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
306
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
307
+ return offset
308
+
309
+ @triton.jit
310
+ def get_bounded_indices(indices, max_len=None):
311
+ return indices % max_len if max_len is not None else indices
312
+
313
+ @triton.jit
314
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
315
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr)
317
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
319
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
320
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
321
+ else:
322
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
323
+
324
+ @triton.jit
325
+ def load_checked_2d(
326
+ ptr,
327
+ offs_m,
328
+ offs_n,
329
+ stride_m,
330
+ stride_n,
331
+ IS_DIVISIBLE_M: tl.constexpr,
332
+ IS_DIVISIBLE_N: tl.constexpr,
333
+ M_LEN: tl.constexpr,
334
+ N_LEN: tl.constexpr,
335
+ ):
336
+ # Calculate final pointer if strides are provided
337
+ if stride_m is not None and stride_n is not None:
338
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
339
+
340
+ # Handle all masking cases
341
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
343
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
345
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
346
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
347
+ else: # Both divisible
348
+ return tl.load(ptr)
349
+
350
+
351
+ # Common Imports
352
+ @triton.jit
353
+ def forward_block_mn(
354
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
355
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
356
+ # accumulated values
357
+ acc, l_i, m_i,
358
+ # Offsets
359
+ off_z, off_h, offs_m, offs_n,
360
+ # Offsets needed for TMA loads
361
+ kv_start,
362
+ kv_offset,
363
+ MATMUL_PRECISION, RCP_LN2,
364
+ # Strides for K and V
365
+ stride_kk, stride_kn, stride_vn, stride_vk,
366
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
367
+
368
+ ):
369
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
370
+ PRESCALE_QK : tl.constexpr = False
371
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
372
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
373
+ WRITE_DQ : tl.constexpr = True
374
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
375
+ OUTPUT_MAX : tl.constexpr = False
376
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
377
+ IS_DIVISIBLE : tl.constexpr = True
378
+ SM_SCALE : tl.constexpr = 0.08838834764831843
379
+ GQA_SHARED_HEADS : tl.constexpr = 4
380
+ HAS_FULL_BLOCKS : tl.constexpr = True
381
+ QK_HEAD_DIM : tl.constexpr = 128
382
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ V_HEAD_DIM : tl.constexpr = 128
384
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
385
+ SAFE_HEAD_DIM : tl.constexpr = True
386
+ USE_TMA : tl.constexpr = False
387
+ BLOCK_M : tl.constexpr = 128
388
+ BLOCK_N : tl.constexpr = 64
389
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
390
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
391
+ INDEX_DTYPE : tl.constexpr = tl.int32
392
+
393
+
394
+ # -- load k --
395
+ # NB: reversed order since K is transposed
396
+ kv_base_offset = kv_start + kv_offset
397
+
398
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
399
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
400
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
401
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
402
+
403
+ k = tl.trans(k)
404
+ # -- compute qk ---
405
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
406
+ if not PRESCALE_QK:
407
+ qk *= SM_SCALE
408
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
409
+ # If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
410
+ # which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
411
+ # we need to mask out the elements that fall outside Q_LEN & KV_LEN.
412
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
413
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
414
+
415
+ tmp0 = (qk)
416
+ post_mod_scores = tmp0
417
+
418
+
419
+ if CHECK_BLOCK_BOUNDARY:
420
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
421
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
422
+
423
+ if not IS_FULL_BLOCKS:
424
+ tmp1 = tl.full([1], False, tl.int1)
425
+ tmp2 = (m)
426
+ tmp3 = (n)
427
+ tmp4 = tmp2 >= tmp3
428
+ tmp5 = tmp3.to(tl.int64)
429
+ tmp6 = (off_z)
430
+ tmp7 = tl.load(in_ptr9 + tmp6)
431
+ tmp8 = tmp5 < tmp7
432
+ tmp9 = tmp2.to(tl.int64)
433
+ tmp10 = tmp9 < tmp7
434
+ tmp11 = tmp8 & tmp10
435
+ tmp12 = tmp4 & tmp11
436
+ tmp13 = tmp1 | tmp12
437
+ tmp14 = tl.full([1], 2048, tl.int32)
438
+ tmp15 = tmp3 >= tmp14
439
+ tmp16 = (tmp3 % tmp14)
440
+ tmp17 = tl.full([1], 0, tl.int32)
441
+ tmp18 = tmp16 != tmp17
442
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
443
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
444
+ tmp21 = tmp19 != tmp20
445
+ tmp22 = tmp18 & tmp21
446
+ tmp23 = tmp16 + tmp14
447
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
448
+ tmp25 = tmp24.to(tl.int64)
449
+ tmp26 = tmp25 < tmp7
450
+ tmp27 = tmp15 & tmp26
451
+ tmp28 = tmp3 - tmp2
452
+ tmp29 = (tmp28 % tmp14)
453
+ tmp30 = tmp29 != tmp17
454
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
455
+ tmp32 = tmp31 != tmp20
456
+ tmp33 = tmp30 & tmp32
457
+ tmp34 = tmp29 + tmp14
458
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
459
+ tmp36 = tmp35 == tmp17
460
+ tmp37 = tmp27 & tmp36
461
+ tmp38 = tmp13 | tmp37
462
+ mask_mod_output = tmp38
463
+
464
+
465
+ if CHECK_BLOCK_BOUNDARY:
466
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
467
+ # apply mask for partially unmasked blocks
468
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
469
+
470
+ if not PRESCALE_QK:
471
+ post_mod_scores *= RCP_LN2
472
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
473
+
474
+ # -- compute scaling constant ---
475
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
476
+ if not ROWS_GUARANTEED_SAFE:
477
+ masked_out_rows = (m_ij == float("-inf"))
478
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
479
+ else:
480
+ m_ij_masked = m_ij
481
+
482
+ alpha = tl.math.exp2(m_i - m_ij_masked)
483
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
484
+
485
+ # NB: l_i update is pulled up here since it's a bit faster
486
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
487
+ # m_ij
488
+ l_i = l_i * alpha + tl.sum(p, 1)
489
+ # # -- scale and update acc --
490
+ acc = acc * alpha[:, None]
491
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
492
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
493
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
494
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
495
+
496
+ # -- update m_i
497
+ m_i = m_ij
498
+
499
+ return acc, l_i, m_i
500
+
501
+ @triton.jit
502
+ def forward_inner(
503
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
504
+ q, K, V,
505
+ desc_k, desc_v, Q_LEN, KV_LEN,
506
+ # accumulated values
507
+ acc, l_i, m_i,
508
+ # Offsets used as inputs to score_mod & mask_mod
509
+ # of size [BLOCK_M, BLOCK_N] or scalar.
510
+ off_z, off_h, offs_m, offs_n,
511
+ # Offsets needed for TMA loads
512
+ kv_start,
513
+ # blocksparse data
514
+ kv_indices, kv_num_blocks,
515
+ # start kv and end kv block
516
+ block_n_start, block_n_end,
517
+ MATMUL_PRECISION,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS,
521
+ ):
522
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
523
+ PRESCALE_QK : tl.constexpr = False
524
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
525
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
526
+ WRITE_DQ : tl.constexpr = True
527
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
528
+ OUTPUT_MAX : tl.constexpr = False
529
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
530
+ IS_DIVISIBLE : tl.constexpr = True
531
+ SM_SCALE : tl.constexpr = 0.08838834764831843
532
+ GQA_SHARED_HEADS : tl.constexpr = 4
533
+ HAS_FULL_BLOCKS : tl.constexpr = True
534
+ QK_HEAD_DIM : tl.constexpr = 128
535
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
536
+ V_HEAD_DIM : tl.constexpr = 128
537
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
538
+ SAFE_HEAD_DIM : tl.constexpr = True
539
+ USE_TMA : tl.constexpr = False
540
+ BLOCK_M : tl.constexpr = 128
541
+ BLOCK_N : tl.constexpr = 64
542
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
543
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
544
+ INDEX_DTYPE : tl.constexpr = tl.int32
545
+
546
+
547
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
548
+ RCP_LN2: tl.constexpr = 1.44269504
549
+
550
+ if PRESCALE_QK:
551
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
552
+
553
+ kv_offset = 0
554
+
555
+ # loop over k, v and update accumulator until block_n_end
556
+ for start_n in range(block_n_start, block_n_end):
557
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
558
+ if IS_DIVISIBLE:
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS,
573
+ )
574
+ else:
575
+ # Benchmarks show that even when we apply mod & mask to every block for non-divisible seqlens,
576
+ # it's on par with, or slightly faster than, applying them only to the last block in fwd.
577
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
578
+ # to the last block, because that is a lot faster there.
579
+ acc, l_i, m_i = forward_block_mn(
580
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
581
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
582
+ # accumulated values
583
+ acc, l_i, m_i,
584
+ # Offsets
585
+ off_z, off_h, offs_m, offs_n,
586
+ # Offsets needed for TMA loads
587
+ kv_start,
588
+ kv_offset,
589
+ MATMUL_PRECISION, RCP_LN2,
590
+ # Strides for K and V
591
+ stride_kk, stride_kn, stride_vn, stride_vk,
592
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
593
+ )
594
+
595
+
596
+
597
+ offset = get_offset_for_next_block(
598
+ start_n, kv_indices, kv_num_blocks,
599
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
600
+ )
601
+
602
+ offs_n = offs_n + offset
603
+ kv_offset += offset
604
+
605
+
606
+ return acc, l_i, m_i
607
+ ''', device_str='cuda')
608
+
609
+
610
+ async_compile.wait(globals())
611
+ del async_compile
612
+
613
+ class Runner:
614
+ def __init__(self, partitions):
615
+ self.partitions = partitions
616
+
617
+ def recursively_apply_fns(self, fns):
618
+ new_callables = []
619
+ for fn, c in zip(fns, self.partitions):
620
+ new_callables.append(fn(c))
621
+ self.partitions = new_callables
622
+
623
+ def call(self, args):
624
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12 = args
625
+ args.clear()
626
+ assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1))
627
+ assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1))
628
+ assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1))
629
+ assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1))
630
+ assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1))
631
+ assert_size_stride(primals_6, (8, ), (1, ))
632
+ assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1))
633
+ assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1))
634
+ assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1))
635
+ assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1))
636
+ assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1))
637
+ assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1))
638
+ with torch.cuda._DeviceGuard(0):
639
+ torch.cuda.set_device(0)
640
+ buf0 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32)
641
+ buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32)
642
+ buf2 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16)
643
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
644
+ stream0 = get_raw_stream(0)
645
+ triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 8, 32, stream=stream0)
646
+ del buf1
647
+ return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, )
648
+
649
+ runner = Runner(partitions=[])
650
+ call = runner.call
651
+ recursively_apply_fns = runner.recursively_apply_fns
652
+
653
+
654
+ def benchmark_compiled_module(times=10, repeat=10):
655
+ from torch._dynamo.testing import rand_strided
656
+ from torch._inductor.utils import print_performance
657
+ primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16)
658
+ primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
659
+ primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16)
660
+ primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32)
661
+ primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32)
662
+ primals_6 = rand_strided((8, ), (1, ), device='cuda:0', dtype=torch.int64)
663
+ primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32)
664
+ primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32)
665
+ primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32)
666
+ primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32)
667
+ primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32)
668
+ primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32)
669
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12])
670
+ return print_performance(fn, times=times, repeat=repeat)
671
+
672
+
673
+ if __name__ == "__main__":
674
+ from torch._inductor.wrapper_benchmark import compiled_module_main
675
+ compiled_module_main('None', benchmark_compiled_module)
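For reference, the running update in forward_block_mn above (m_ij, alpha, p, l_i, acc) is the standard online-softmax recurrence used by flash-style attention: each KV block updates a running row max, rescales the previous accumulator, and adds the new block's contribution. Below is a minimal PyTorch sketch of that bookkeeping; the tensor names mirror the kernel, but the function itself is illustrative and not part of the generated cache (the real kernel additionally guards fully masked rows and normalizes acc by l_i in its epilogue, which is not shown in this excerpt).

import torch

def online_softmax_block_update(acc, l_i, m_i, scores, v):
    # acc:    [M, Dv] running (unnormalized) output accumulator
    # l_i:    [M]     running softmax denominator
    # m_i:    [M]     running row max (the kernel works in base-2 log space)
    # scores: [M, N]  post-mask scores for this KV block, already scaled into log2 space
    # v:      [N, Dv] value block
    m_ij = torch.maximum(m_i, scores.max(dim=1).values)   # new running max
    alpha = torch.exp2(m_i - m_ij)                         # rescale factor for the old state
    p = torch.exp2(scores - m_ij[:, None])                 # unnormalized block probabilities
    l_i = l_i * alpha + p.sum(dim=1)                       # update the denominator
    acc = acc * alpha[:, None] + p @ v                     # rescale old acc, add this block
    return acc, l_i, m_ij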
SpecForge-ext/cache/compiled_kernels/av/cavp7xan77tfr7qytfkp6sjrgkd6hvruiaqfzkeibtl5rtagscng.py ADDED
@@ -0,0 +1,99 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 512, 'r0_': 16384},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 16384
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x1 = ((xindex // ks0) % ks1)
28
+ x0 = (xindex % ks0)
29
+ x2 = xindex // ks4
30
+ _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
31
+ x5 = xindex
32
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
33
+ r0_index = r0_offset + r0_base
34
+ r0_mask = r0_index < r0_numel
35
+ roffset = r0_offset
36
+ rindex = r0_index
37
+ r0_4 = r0_index // 128
38
+ r0_3 = (r0_index % 128)
39
+ tmp0 = r0_4 + 128*x1
40
+ tmp1 = ks2
41
+ tmp2 = tmp0 < tmp1
42
+ tmp3 = r0_3 + 128*x0
43
+ tmp4 = ks3
44
+ tmp5 = tmp3 < tmp4
45
+ tmp6 = tmp2 & tmp5
46
+ tmp7 = r0_4 + 128*x1
47
+ tmp8 = r0_3 + 128*x0
48
+ tmp9 = tmp7 >= tmp8
49
+ tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0)
50
+ tmp11 = tmp8 < tmp10
51
+ tmp12 = tmp7 < tmp10
52
+ tmp13 = tmp11 & tmp12
53
+ tmp14 = tmp9 & tmp13
54
+ tmp15 = tl.full([1, 1], False, tl.int1)
55
+ tmp16 = tmp15 | tmp14
56
+ tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK])
57
+ tmp18 = tmp8 >= tmp17
58
+ tmp19 = (tmp8 % tmp17)
59
+ tmp20 = tl.full([1, 1], 0, tl.int32)
60
+ tmp21 = tmp19 != tmp20
61
+ tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0
62
+ tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0
63
+ tmp24 = tmp22 != tmp23
64
+ tmp25 = tmp21 & tmp24
65
+ tmp26 = tmp19 + tmp17
66
+ tmp27 = tl.where(tmp25, tmp26, tmp19)
67
+ tmp28 = tmp27 < tmp10
68
+ tmp29 = tmp18 & tmp28
69
+ tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
70
+ tmp31 = (tmp30 % tmp17)
71
+ tmp32 = tmp31 != tmp20
72
+ tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0
73
+ tmp34 = tmp33 != tmp23
74
+ tmp35 = tmp32 & tmp34
75
+ tmp36 = tmp31 + tmp17
76
+ tmp37 = tl.where(tmp35, tmp36, tmp31)
77
+ tmp38 = tl.full([1, 1], 0, tl.int64)
78
+ tmp39 = tmp37 == tmp38
79
+ tmp40 = tmp29 & tmp39
80
+ tmp41 = tmp16 | tmp40
81
+ tmp42 = tl.full(tmp41.shape, False, tmp41.dtype)
82
+ tmp43 = tl.where(tmp6, tmp41, tmp42)
83
+ tmp44 = tmp43.to(tl.int64)
84
+ tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK])
85
+ tmp47 = _tmp46 + tmp45
86
+ _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46)
87
+ tmp46 = tl.sum(_tmp46, 1)[:, None]
88
+ tmp48 = tl.full([1, 1], 0, tl.int64)
89
+ tmp49 = tmp46 > tmp48
90
+ tmp50 = tl.full([1, 1], 16384, tl.int64)
91
+ tmp51 = tmp46 < tmp50
92
+ tmp52 = tmp49 & tmp51
93
+ tmp53 = tmp52.to(tl.int8)
94
+ tmp54 = tmp53.to(tl.int32)
95
+ tmp55 = tmp46 == tmp50
96
+ tmp56 = tmp55.to(tl.int8)
97
+ tmp57 = tmp56.to(tl.int32)
98
+ tl.store(out_ptr1 + (x5), tmp54, xmask)
99
+ tl.store(out_ptr2 + (x5), tmp57, xmask)
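The reduction kernel above sums a boolean mask expression over each 128x128 tile (r0_numel = 16384) and then writes two int32 flags per tile: one marking tiles that are only partially unmasked (0 < count < 16384) and one marking tiles that are fully unmasked (count == 16384). A rough CPU-side equivalent, assuming the mask is already materialized and its dimensions divide evenly by the block size (the kernel additionally handles the padded, non-divisible case), might look like the sketch below; the names are ours, not part of the cache.

import torch

def classify_mask_blocks(mask, BLOCK=128):
    # mask: bool tensor of shape [..., Q_LEN, KV_LEN]; True means the position is attended.
    *lead, q_len, kv_len = mask.shape
    qb, kb = q_len // BLOCK, kv_len // BLOCK
    tiles = mask.reshape(*lead, qb, BLOCK, kb, BLOCK)
    counts = tiles.sum(dim=(-3, -1))                  # unmasked positions per BLOCK x BLOCK tile
    full = BLOCK * BLOCK
    partial_blocks = ((counts > 0) & (counts < full)).to(torch.int32)  # analogous to out_ptr1
    full_blocks = (counts == full).to(torch.int32)                     # analogous to out_ptr2
    return partial_blocks, full_blocks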
SpecForge-ext/cache/compiled_kernels/bd/cbdpymknkquuerovirx6corahubfs5khfhys2add2b3c2zkuvlup.py ADDED
@@ -0,0 +1,835 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = True
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, which is written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1
101
+
102
+ ZQ = 8
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 8
107
+ KV_LEN = 2048
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 8
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 256
149
+ stride_kv_idx_m = 16
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 16
245
+ stride_q_idx_h = 256
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = True
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 2048
385
+ KV_LEN = 2048
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = True
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done for the outer loop, but since we're iterating across the N dim here,
466
+ # the M indices can still read out of bounds for PIDs spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = tl.full([1], 2048, tl.int32)
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = True
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = 2048
596
+ KV_LEN = 2048
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = True
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done for the outer loop, but since we're iterating across the M dim here,
682
+ # the N indices can still read out of bounds for PIDs spanning the KV_LEN boundary
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = tl.full([1], 2048, tl.int32)
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
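A pattern that recurs throughout the generated mask logic in this file and the previous one (e.g. the tmp16..tmp24 and tmp56..tmp64 sequences above) is a remainder correction: Triton's % follows C semantics, truncating the quotient toward zero, so Inductor emits a signbit/where sequence to recover Python-style floor modulo when operand signs may differ. A small plain-Python illustration of what that sequence computes (our names; integer inputs and a nonzero divisor assumed):

def trunc_mod(a, n):
    # C/Triton-style remainder: quotient truncated toward zero.
    return a - n * int(a / n)

def floor_mod(a, n):
    # Python-style modulo rebuilt from the truncated remainder -- the
    # "r != 0 and sign(r) != sign(n): r += n" adjustment emitted above.
    r = trunc_mod(a, n)
    if r != 0 and ((r < 0) != (n < 0)):
        r += n
    return r

assert floor_mod(-3, 2048) == (-3) % 2048 == 2045
assert floor_mod(5, 2048) == 5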
SpecForge-ext/cache/compiled_kernels/bi/8786fd641e91216a3bc7781055fbc9277e1637f9f319eaed8124e438ba94886f.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "E2MI47QNGZ2SJDA3U3EKHN7H3EYRAANF6T7N5SFT2CZJYNBAWCNQ"}
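Each .best_config file in this cache is a single JSON object recording the launch configuration Inductor's autotuner selected for the matching kernel (block sizes such as XBLOCK, plus num_warps, num_stages, and bookkeeping hashes/timings). A minimal, purely illustrative reader:

import json
from pathlib import Path

def read_best_config(path):
    # Keep only the tuning fields; the remaining keys are cache bookkeeping.
    cfg = json.loads(Path(path).read_text())
    return {k: v for k, v in cfg.items() if k in ("XBLOCK", "num_warps", "num_stages")}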