nologik commited on
Commit
64de0bd
·
verified ·
1 Parent(s): 057d8be

Uploaded using `kernel-builder`.

Browse files
build/torch210-cxx11-cu130-aarch64-linux/__init__.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: AGPL-3.0-only
2
+ """Atlas Gated DeltaNet kernels for NVIDIA GB10 (SM121).
3
+
4
+ These kernels back the linear-attention path of Qwen3.6 hybrid models
5
+ (27B dense and 35B-A3B sparse). They are hand-tuned for the unified
6
+ LPDDR5X memory layout of the DGX Spark and pinned to compute
7
+ capability 12.1 — they will not load on any other GPU.
8
+ """
9
+
10
+ from typing import Optional
11
+
12
+ import torch
13
+
14
+ from ._ops import ops
15
+
16
+ __all__ = [
17
+ "gdn_decode",
18
+ "gdn_prefill",
19
+ "gdn_chunk2",
20
+ "gdn_chunk3",
21
+ "gdn_wy2",
22
+ "gdn_wy3",
23
+ "gdn_wy4",
24
+ "causal_conv1d_fwd",
25
+ "causal_conv1d_update",
26
+ ]
27
+
28
+
29
+ def gdn_decode(
30
+ h_state: torch.Tensor,
31
+ query: torch.Tensor,
32
+ key: torch.Tensor,
33
+ value: torch.Tensor,
34
+ gate: torch.Tensor,
35
+ beta: torch.Tensor,
36
+ output: torch.Tensor,
37
+ ) -> None:
38
+ """Single-token GDN decode (in-place update of ``h_state`` and ``output``).
39
+
40
+ The recurrent path keeps Q/K/V in FP32 to avoid the precision drift
41
+ that BF16 inputs cause over long contexts in hybrid models.
42
+
43
+ Shapes
44
+ ------
45
+ h_state : (B, num_v_heads, k_dim, v_dim) float32, in-place updated
46
+ query : (B, num_k_heads, k_dim) float32
47
+ key : (B, num_k_heads, k_dim) float32
48
+ value : (B, num_v_heads, v_dim) float32
49
+ gate : (B, num_v_heads) float32 (exp(g_t) decay)
50
+ beta : (B, num_v_heads) float32 (sigmoid(b_t))
51
+ output : (B, num_v_heads, v_dim) bfloat16, in-place written
52
+ """
53
+ ops.gdn_decode(h_state, query, key, value, gate, beta, output)
54
+
55
+
56
+ def gdn_prefill(
57
+ h_state: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key: torch.Tensor,
60
+ value: torch.Tensor,
61
+ gate: torch.Tensor,
62
+ beta: torch.Tensor,
63
+ output: torch.Tensor,
64
+ ) -> None:
65
+ """Multi-token GDN prefill (one batch, one chunk).
66
+
67
+ Shapes
68
+ ------
69
+ h_state : (B, num_v_heads, k_dim, v_dim) float32
70
+ query : (B, seq_len, num_k_heads, k_dim) bfloat16
71
+ key : (B, seq_len, num_k_heads, k_dim) bfloat16
72
+ value : (B, seq_len, num_v_heads, v_dim) bfloat16
73
+ gate : (B, seq_len, num_v_heads) float32
74
+ beta : (B, seq_len, num_v_heads) float32
75
+ output : (B, seq_len, num_v_heads, v_dim) bfloat16
76
+ """
77
+ ops.gdn_prefill(h_state, query, key, value, gate, beta, output)
78
+
79
+
80
+ def gdn_chunk2(
81
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
82
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
83
+ output: torch.Tensor, h_state_intermediate: torch.Tensor,
84
+ ) -> None:
85
+ """K=2 chunkwise verify (MTP draft length 1).
86
+
87
+ Writes the intermediate state after token 0 to ``h_state_intermediate``
88
+ so the caller can roll back when token 1 is rejected.
89
+ """
90
+ ops.gdn_chunk2(h_state, query, key, value, gate, beta, output, h_state_intermediate)
91
+
92
+
93
+ def gdn_chunk3(
94
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
95
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
96
+ output: torch.Tensor,
97
+ h_state_inter0: torch.Tensor, h_state_inter1: torch.Tensor,
98
+ ) -> None:
99
+ """K=3 chunkwise verify (MTP draft length 2)."""
100
+ ops.gdn_chunk3(h_state, query, key, value, gate, beta, output,
101
+ h_state_inter0, h_state_inter1)
102
+
103
+
104
+ def gdn_wy2(
105
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
106
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
107
+ output: torch.Tensor, h_state_intermediate: torch.Tensor,
108
+ ) -> None:
109
+ """2-pass WY-chunkwise K=2 verify (replaces chunk2 at higher acceptance)."""
110
+ ops.gdn_wy2(h_state, query, key, value, gate, beta, output, h_state_intermediate)
111
+
112
+
113
+ def gdn_wy3(
114
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
115
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
116
+ output: torch.Tensor, h_state_inter0: torch.Tensor, h_state_inter1: torch.Tensor,
117
+ ) -> None:
118
+ """2-pass WY-chunkwise K=3 verify."""
119
+ ops.gdn_wy3(h_state, query, key, value, gate, beta, output,
120
+ h_state_inter0, h_state_inter1)
121
+
122
+
123
+ def gdn_wy4(
124
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
125
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
126
+ output: torch.Tensor,
127
+ h_state_inter0: torch.Tensor,
128
+ h_state_inter1: torch.Tensor,
129
+ h_state_inter2: torch.Tensor,
130
+ ) -> None:
131
+ """2-pass WY-chunkwise K=4 verify."""
132
+ ops.gdn_wy4(h_state, query, key, value, gate, beta, output,
133
+ h_state_inter0, h_state_inter1, h_state_inter2)
134
+
135
+
136
+ def causal_conv1d_fwd(
137
+ x: torch.Tensor,
138
+ weight: torch.Tensor,
139
+ bias: Optional[torch.Tensor],
140
+ out: torch.Tensor,
141
+ ) -> None:
142
+ """Depthwise causal Conv1d forward (used by the SSM input projection).
143
+
144
+ x : (B, D, L) bfloat16
145
+ weight : (D, d_conv) bfloat16
146
+ bias : (D,) float32 or None
147
+ out : (B, D, L) bfloat16
148
+ """
149
+ ops.causal_conv1d_fwd(x, weight, bias, out)
150
+
151
+
152
+ def causal_conv1d_update(
153
+ conv_state: torch.Tensor,
154
+ x: torch.Tensor,
155
+ weight: torch.Tensor,
156
+ bias: Optional[torch.Tensor],
157
+ out: torch.Tensor,
158
+ ) -> None:
159
+ """Single-step causal Conv1d update (single-token decode path).
160
+
161
+ conv_state : (B, D, d_conv) float32, in-place updated (rolled left, last slot = x)
162
+ x : (B, D) bfloat16 (new input)
163
+ weight : (D, d_conv) bfloat16
164
+ bias : (D,) float32 or None
165
+ out : (B, D) bfloat16
166
+ """
167
+ ops.causal_conv1d_update(conv_state, x, weight, bias, out)
build/torch210-cxx11-cu130-aarch64-linux/_gdn_cuda_bcba8ad.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c28c8e0849b7ebff1a86d47b56b8036b08c654b94f7489f080d5d5b98ac3091
3
+ size 3242624
build/torch210-cxx11-cu130-aarch64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _gdn_cuda_bcba8ad
3
+ ops = torch.ops._gdn_cuda_bcba8ad
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_gdn_cuda_bcba8ad::{op_name}"
build/torch210-cxx11-cu130-aarch64-linux/gdn/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu130-aarch64-linux/metadata.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "gdn",
3
+ "id": "_gdn_cuda_bcba8ad",
4
+ "version": 0,
5
+ "license": "AGPL-3.0-only",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda"
9
+ }
10
+ }
build/torch211-cxx11-cu130-aarch64-linux/__init__.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: AGPL-3.0-only
2
+ """Atlas Gated DeltaNet kernels for NVIDIA GB10 (SM121).
3
+
4
+ These kernels back the linear-attention path of Qwen3.6 hybrid models
5
+ (27B dense and 35B-A3B sparse). They are hand-tuned for the unified
6
+ LPDDR5X memory layout of the DGX Spark and pinned to compute
7
+ capability 12.1 — they will not load on any other GPU.
8
+ """
9
+
10
+ from typing import Optional
11
+
12
+ import torch
13
+
14
+ from ._ops import ops
15
+
16
+ __all__ = [
17
+ "gdn_decode",
18
+ "gdn_prefill",
19
+ "gdn_chunk2",
20
+ "gdn_chunk3",
21
+ "gdn_wy2",
22
+ "gdn_wy3",
23
+ "gdn_wy4",
24
+ "causal_conv1d_fwd",
25
+ "causal_conv1d_update",
26
+ ]
27
+
28
+
29
+ def gdn_decode(
30
+ h_state: torch.Tensor,
31
+ query: torch.Tensor,
32
+ key: torch.Tensor,
33
+ value: torch.Tensor,
34
+ gate: torch.Tensor,
35
+ beta: torch.Tensor,
36
+ output: torch.Tensor,
37
+ ) -> None:
38
+ """Single-token GDN decode (in-place update of ``h_state`` and ``output``).
39
+
40
+ The recurrent path keeps Q/K/V in FP32 to avoid the precision drift
41
+ that BF16 inputs cause over long contexts in hybrid models.
42
+
43
+ Shapes
44
+ ------
45
+ h_state : (B, num_v_heads, k_dim, v_dim) float32, in-place updated
46
+ query : (B, num_k_heads, k_dim) float32
47
+ key : (B, num_k_heads, k_dim) float32
48
+ value : (B, num_v_heads, v_dim) float32
49
+ gate : (B, num_v_heads) float32 (exp(g_t) decay)
50
+ beta : (B, num_v_heads) float32 (sigmoid(b_t))
51
+ output : (B, num_v_heads, v_dim) bfloat16, in-place written
52
+ """
53
+ ops.gdn_decode(h_state, query, key, value, gate, beta, output)
54
+
55
+
56
+ def gdn_prefill(
57
+ h_state: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key: torch.Tensor,
60
+ value: torch.Tensor,
61
+ gate: torch.Tensor,
62
+ beta: torch.Tensor,
63
+ output: torch.Tensor,
64
+ ) -> None:
65
+ """Multi-token GDN prefill (one batch, one chunk).
66
+
67
+ Shapes
68
+ ------
69
+ h_state : (B, num_v_heads, k_dim, v_dim) float32
70
+ query : (B, seq_len, num_k_heads, k_dim) bfloat16
71
+ key : (B, seq_len, num_k_heads, k_dim) bfloat16
72
+ value : (B, seq_len, num_v_heads, v_dim) bfloat16
73
+ gate : (B, seq_len, num_v_heads) float32
74
+ beta : (B, seq_len, num_v_heads) float32
75
+ output : (B, seq_len, num_v_heads, v_dim) bfloat16
76
+ """
77
+ ops.gdn_prefill(h_state, query, key, value, gate, beta, output)
78
+
79
+
80
+ def gdn_chunk2(
81
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
82
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
83
+ output: torch.Tensor, h_state_intermediate: torch.Tensor,
84
+ ) -> None:
85
+ """K=2 chunkwise verify (MTP draft length 1).
86
+
87
+ Writes the intermediate state after token 0 to ``h_state_intermediate``
88
+ so the caller can roll back when token 1 is rejected.
89
+ """
90
+ ops.gdn_chunk2(h_state, query, key, value, gate, beta, output, h_state_intermediate)
91
+
92
+
93
+ def gdn_chunk3(
94
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
95
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
96
+ output: torch.Tensor,
97
+ h_state_inter0: torch.Tensor, h_state_inter1: torch.Tensor,
98
+ ) -> None:
99
+ """K=3 chunkwise verify (MTP draft length 2)."""
100
+ ops.gdn_chunk3(h_state, query, key, value, gate, beta, output,
101
+ h_state_inter0, h_state_inter1)
102
+
103
+
104
+ def gdn_wy2(
105
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
106
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
107
+ output: torch.Tensor, h_state_intermediate: torch.Tensor,
108
+ ) -> None:
109
+ """2-pass WY-chunkwise K=2 verify (replaces chunk2 at higher acceptance)."""
110
+ ops.gdn_wy2(h_state, query, key, value, gate, beta, output, h_state_intermediate)
111
+
112
+
113
+ def gdn_wy3(
114
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
115
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
116
+ output: torch.Tensor, h_state_inter0: torch.Tensor, h_state_inter1: torch.Tensor,
117
+ ) -> None:
118
+ """2-pass WY-chunkwise K=3 verify."""
119
+ ops.gdn_wy3(h_state, query, key, value, gate, beta, output,
120
+ h_state_inter0, h_state_inter1)
121
+
122
+
123
+ def gdn_wy4(
124
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
125
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
126
+ output: torch.Tensor,
127
+ h_state_inter0: torch.Tensor,
128
+ h_state_inter1: torch.Tensor,
129
+ h_state_inter2: torch.Tensor,
130
+ ) -> None:
131
+ """2-pass WY-chunkwise K=4 verify."""
132
+ ops.gdn_wy4(h_state, query, key, value, gate, beta, output,
133
+ h_state_inter0, h_state_inter1, h_state_inter2)
134
+
135
+
136
+ def causal_conv1d_fwd(
137
+ x: torch.Tensor,
138
+ weight: torch.Tensor,
139
+ bias: Optional[torch.Tensor],
140
+ out: torch.Tensor,
141
+ ) -> None:
142
+ """Depthwise causal Conv1d forward (used by the SSM input projection).
143
+
144
+ x : (B, D, L) bfloat16
145
+ weight : (D, d_conv) bfloat16
146
+ bias : (D,) float32 or None
147
+ out : (B, D, L) bfloat16
148
+ """
149
+ ops.causal_conv1d_fwd(x, weight, bias, out)
150
+
151
+
152
+ def causal_conv1d_update(
153
+ conv_state: torch.Tensor,
154
+ x: torch.Tensor,
155
+ weight: torch.Tensor,
156
+ bias: Optional[torch.Tensor],
157
+ out: torch.Tensor,
158
+ ) -> None:
159
+ """Single-step causal Conv1d update (single-token decode path).
160
+
161
+ conv_state : (B, D, d_conv) float32, in-place updated (rolled left, last slot = x)
162
+ x : (B, D) bfloat16 (new input)
163
+ weight : (D, d_conv) bfloat16
164
+ bias : (D,) float32 or None
165
+ out : (B, D) bfloat16
166
+ """
167
+ ops.causal_conv1d_update(conv_state, x, weight, bias, out)
build/torch211-cxx11-cu130-aarch64-linux/_gdn_cuda_bcba8ad.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92cdec32c4ef029a154eeb37ee7e889d00144c82496e59699c76175dd5df4084
3
+ size 3242624
build/torch211-cxx11-cu130-aarch64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _gdn_cuda_bcba8ad
3
+ ops = torch.ops._gdn_cuda_bcba8ad
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_gdn_cuda_bcba8ad::{op_name}"
build/torch211-cxx11-cu130-aarch64-linux/gdn/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch211-cxx11-cu130-aarch64-linux/metadata.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "gdn",
3
+ "id": "_gdn_cuda_bcba8ad",
4
+ "version": 0,
5
+ "license": "AGPL-3.0-only",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda"
9
+ }
10
+ }
build/torch212-cxx11-cu130-aarch64-linux/__init__.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: AGPL-3.0-only
2
+ """Atlas Gated DeltaNet kernels for NVIDIA GB10 (SM121).
3
+
4
+ These kernels back the linear-attention path of Qwen3.6 hybrid models
5
+ (27B dense and 35B-A3B sparse). They are hand-tuned for the unified
6
+ LPDDR5X memory layout of the DGX Spark and pinned to compute
7
+ capability 12.1 — they will not load on any other GPU.
8
+ """
9
+
10
+ from typing import Optional
11
+
12
+ import torch
13
+
14
+ from ._ops import ops
15
+
16
+ __all__ = [
17
+ "gdn_decode",
18
+ "gdn_prefill",
19
+ "gdn_chunk2",
20
+ "gdn_chunk3",
21
+ "gdn_wy2",
22
+ "gdn_wy3",
23
+ "gdn_wy4",
24
+ "causal_conv1d_fwd",
25
+ "causal_conv1d_update",
26
+ ]
27
+
28
+
29
+ def gdn_decode(
30
+ h_state: torch.Tensor,
31
+ query: torch.Tensor,
32
+ key: torch.Tensor,
33
+ value: torch.Tensor,
34
+ gate: torch.Tensor,
35
+ beta: torch.Tensor,
36
+ output: torch.Tensor,
37
+ ) -> None:
38
+ """Single-token GDN decode (in-place update of ``h_state`` and ``output``).
39
+
40
+ The recurrent path keeps Q/K/V in FP32 to avoid the precision drift
41
+ that BF16 inputs cause over long contexts in hybrid models.
42
+
43
+ Shapes
44
+ ------
45
+ h_state : (B, num_v_heads, k_dim, v_dim) float32, in-place updated
46
+ query : (B, num_k_heads, k_dim) float32
47
+ key : (B, num_k_heads, k_dim) float32
48
+ value : (B, num_v_heads, v_dim) float32
49
+ gate : (B, num_v_heads) float32 (exp(g_t) decay)
50
+ beta : (B, num_v_heads) float32 (sigmoid(b_t))
51
+ output : (B, num_v_heads, v_dim) bfloat16, in-place written
52
+ """
53
+ ops.gdn_decode(h_state, query, key, value, gate, beta, output)
54
+
55
+
56
+ def gdn_prefill(
57
+ h_state: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key: torch.Tensor,
60
+ value: torch.Tensor,
61
+ gate: torch.Tensor,
62
+ beta: torch.Tensor,
63
+ output: torch.Tensor,
64
+ ) -> None:
65
+ """Multi-token GDN prefill (one batch, one chunk).
66
+
67
+ Shapes
68
+ ------
69
+ h_state : (B, num_v_heads, k_dim, v_dim) float32
70
+ query : (B, seq_len, num_k_heads, k_dim) bfloat16
71
+ key : (B, seq_len, num_k_heads, k_dim) bfloat16
72
+ value : (B, seq_len, num_v_heads, v_dim) bfloat16
73
+ gate : (B, seq_len, num_v_heads) float32
74
+ beta : (B, seq_len, num_v_heads) float32
75
+ output : (B, seq_len, num_v_heads, v_dim) bfloat16
76
+ """
77
+ ops.gdn_prefill(h_state, query, key, value, gate, beta, output)
78
+
79
+
80
+ def gdn_chunk2(
81
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
82
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
83
+ output: torch.Tensor, h_state_intermediate: torch.Tensor,
84
+ ) -> None:
85
+ """K=2 chunkwise verify (MTP draft length 1).
86
+
87
+ Writes the intermediate state after token 0 to ``h_state_intermediate``
88
+ so the caller can roll back when token 1 is rejected.
89
+ """
90
+ ops.gdn_chunk2(h_state, query, key, value, gate, beta, output, h_state_intermediate)
91
+
92
+
93
+ def gdn_chunk3(
94
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
95
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
96
+ output: torch.Tensor,
97
+ h_state_inter0: torch.Tensor, h_state_inter1: torch.Tensor,
98
+ ) -> None:
99
+ """K=3 chunkwise verify (MTP draft length 2)."""
100
+ ops.gdn_chunk3(h_state, query, key, value, gate, beta, output,
101
+ h_state_inter0, h_state_inter1)
102
+
103
+
104
+ def gdn_wy2(
105
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
106
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
107
+ output: torch.Tensor, h_state_intermediate: torch.Tensor,
108
+ ) -> None:
109
+ """2-pass WY-chunkwise K=2 verify (replaces chunk2 at higher acceptance)."""
110
+ ops.gdn_wy2(h_state, query, key, value, gate, beta, output, h_state_intermediate)
111
+
112
+
113
+ def gdn_wy3(
114
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
115
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
116
+ output: torch.Tensor, h_state_inter0: torch.Tensor, h_state_inter1: torch.Tensor,
117
+ ) -> None:
118
+ """2-pass WY-chunkwise K=3 verify."""
119
+ ops.gdn_wy3(h_state, query, key, value, gate, beta, output,
120
+ h_state_inter0, h_state_inter1)
121
+
122
+
123
+ def gdn_wy4(
124
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
125
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
126
+ output: torch.Tensor,
127
+ h_state_inter0: torch.Tensor,
128
+ h_state_inter1: torch.Tensor,
129
+ h_state_inter2: torch.Tensor,
130
+ ) -> None:
131
+ """2-pass WY-chunkwise K=4 verify."""
132
+ ops.gdn_wy4(h_state, query, key, value, gate, beta, output,
133
+ h_state_inter0, h_state_inter1, h_state_inter2)
134
+
135
+
136
+ def causal_conv1d_fwd(
137
+ x: torch.Tensor,
138
+ weight: torch.Tensor,
139
+ bias: Optional[torch.Tensor],
140
+ out: torch.Tensor,
141
+ ) -> None:
142
+ """Depthwise causal Conv1d forward (used by the SSM input projection).
143
+
144
+ x : (B, D, L) bfloat16
145
+ weight : (D, d_conv) bfloat16
146
+ bias : (D,) float32 or None
147
+ out : (B, D, L) bfloat16
148
+ """
149
+ ops.causal_conv1d_fwd(x, weight, bias, out)
150
+
151
+
152
+ def causal_conv1d_update(
153
+ conv_state: torch.Tensor,
154
+ x: torch.Tensor,
155
+ weight: torch.Tensor,
156
+ bias: Optional[torch.Tensor],
157
+ out: torch.Tensor,
158
+ ) -> None:
159
+ """Single-step causal Conv1d update (single-token decode path).
160
+
161
+ conv_state : (B, D, d_conv) float32, in-place updated (rolled left, last slot = x)
162
+ x : (B, D) bfloat16 (new input)
163
+ weight : (D, d_conv) bfloat16
164
+ bias : (D,) float32 or None
165
+ out : (B, D) bfloat16
166
+ """
167
+ ops.causal_conv1d_update(conv_state, x, weight, bias, out)
build/torch212-cxx11-cu130-aarch64-linux/_gdn_cuda_bcba8ad.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4696f54be4ebace851150e8a68f405417453a03f1f94d359721c8141b908df20
3
+ size 3242648
build/torch212-cxx11-cu130-aarch64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _gdn_cuda_bcba8ad
3
+ ops = torch.ops._gdn_cuda_bcba8ad
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_gdn_cuda_bcba8ad::{op_name}"
build/torch212-cxx11-cu130-aarch64-linux/gdn/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch212-cxx11-cu130-aarch64-linux/metadata.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "gdn",
3
+ "id": "_gdn_cuda_bcba8ad",
4
+ "version": 0,
5
+ "license": "AGPL-3.0-only",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda"
9
+ }
10
+ }
build/torch212-cxx11-cu132-aarch64-linux/__init__.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: AGPL-3.0-only
2
+ """Atlas Gated DeltaNet kernels for NVIDIA GB10 (SM121).
3
+
4
+ These kernels back the linear-attention path of Qwen3.6 hybrid models
5
+ (27B dense and 35B-A3B sparse). They are hand-tuned for the unified
6
+ LPDDR5X memory layout of the DGX Spark and pinned to compute
7
+ capability 12.1 — they will not load on any other GPU.
8
+ """
9
+
10
+ from typing import Optional
11
+
12
+ import torch
13
+
14
+ from ._ops import ops
15
+
16
+ __all__ = [
17
+ "gdn_decode",
18
+ "gdn_prefill",
19
+ "gdn_chunk2",
20
+ "gdn_chunk3",
21
+ "gdn_wy2",
22
+ "gdn_wy3",
23
+ "gdn_wy4",
24
+ "causal_conv1d_fwd",
25
+ "causal_conv1d_update",
26
+ ]
27
+
28
+
29
+ def gdn_decode(
30
+ h_state: torch.Tensor,
31
+ query: torch.Tensor,
32
+ key: torch.Tensor,
33
+ value: torch.Tensor,
34
+ gate: torch.Tensor,
35
+ beta: torch.Tensor,
36
+ output: torch.Tensor,
37
+ ) -> None:
38
+ """Single-token GDN decode (in-place update of ``h_state`` and ``output``).
39
+
40
+ The recurrent path keeps Q/K/V in FP32 to avoid the precision drift
41
+ that BF16 inputs cause over long contexts in hybrid models.
42
+
43
+ Shapes
44
+ ------
45
+ h_state : (B, num_v_heads, k_dim, v_dim) float32, in-place updated
46
+ query : (B, num_k_heads, k_dim) float32
47
+ key : (B, num_k_heads, k_dim) float32
48
+ value : (B, num_v_heads, v_dim) float32
49
+ gate : (B, num_v_heads) float32 (exp(g_t) decay)
50
+ beta : (B, num_v_heads) float32 (sigmoid(b_t))
51
+ output : (B, num_v_heads, v_dim) bfloat16, in-place written
52
+ """
53
+ ops.gdn_decode(h_state, query, key, value, gate, beta, output)
54
+
55
+
56
+ def gdn_prefill(
57
+ h_state: torch.Tensor,
58
+ query: torch.Tensor,
59
+ key: torch.Tensor,
60
+ value: torch.Tensor,
61
+ gate: torch.Tensor,
62
+ beta: torch.Tensor,
63
+ output: torch.Tensor,
64
+ ) -> None:
65
+ """Multi-token GDN prefill (one batch, one chunk).
66
+
67
+ Shapes
68
+ ------
69
+ h_state : (B, num_v_heads, k_dim, v_dim) float32
70
+ query : (B, seq_len, num_k_heads, k_dim) bfloat16
71
+ key : (B, seq_len, num_k_heads, k_dim) bfloat16
72
+ value : (B, seq_len, num_v_heads, v_dim) bfloat16
73
+ gate : (B, seq_len, num_v_heads) float32
74
+ beta : (B, seq_len, num_v_heads) float32
75
+ output : (B, seq_len, num_v_heads, v_dim) bfloat16
76
+ """
77
+ ops.gdn_prefill(h_state, query, key, value, gate, beta, output)
78
+
79
+
80
+ def gdn_chunk2(
81
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
82
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
83
+ output: torch.Tensor, h_state_intermediate: torch.Tensor,
84
+ ) -> None:
85
+ """K=2 chunkwise verify (MTP draft length 1).
86
+
87
+ Writes the intermediate state after token 0 to ``h_state_intermediate``
88
+ so the caller can roll back when token 1 is rejected.
89
+ """
90
+ ops.gdn_chunk2(h_state, query, key, value, gate, beta, output, h_state_intermediate)
91
+
92
+
93
+ def gdn_chunk3(
94
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
95
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
96
+ output: torch.Tensor,
97
+ h_state_inter0: torch.Tensor, h_state_inter1: torch.Tensor,
98
+ ) -> None:
99
+ """K=3 chunkwise verify (MTP draft length 2)."""
100
+ ops.gdn_chunk3(h_state, query, key, value, gate, beta, output,
101
+ h_state_inter0, h_state_inter1)
102
+
103
+
104
+ def gdn_wy2(
105
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
106
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
107
+ output: torch.Tensor, h_state_intermediate: torch.Tensor,
108
+ ) -> None:
109
+ """2-pass WY-chunkwise K=2 verify (replaces chunk2 at higher acceptance)."""
110
+ ops.gdn_wy2(h_state, query, key, value, gate, beta, output, h_state_intermediate)
111
+
112
+
113
+ def gdn_wy3(
114
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
115
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
116
+ output: torch.Tensor, h_state_inter0: torch.Tensor, h_state_inter1: torch.Tensor,
117
+ ) -> None:
118
+ """2-pass WY-chunkwise K=3 verify."""
119
+ ops.gdn_wy3(h_state, query, key, value, gate, beta, output,
120
+ h_state_inter0, h_state_inter1)
121
+
122
+
123
+ def gdn_wy4(
124
+ h_state: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
125
+ value: torch.Tensor, gate: torch.Tensor, beta: torch.Tensor,
126
+ output: torch.Tensor,
127
+ h_state_inter0: torch.Tensor,
128
+ h_state_inter1: torch.Tensor,
129
+ h_state_inter2: torch.Tensor,
130
+ ) -> None:
131
+ """2-pass WY-chunkwise K=4 verify."""
132
+ ops.gdn_wy4(h_state, query, key, value, gate, beta, output,
133
+ h_state_inter0, h_state_inter1, h_state_inter2)
134
+
135
+
136
+ def causal_conv1d_fwd(
137
+ x: torch.Tensor,
138
+ weight: torch.Tensor,
139
+ bias: Optional[torch.Tensor],
140
+ out: torch.Tensor,
141
+ ) -> None:
142
+ """Depthwise causal Conv1d forward (used by the SSM input projection).
143
+
144
+ x : (B, D, L) bfloat16
145
+ weight : (D, d_conv) bfloat16
146
+ bias : (D,) float32 or None
147
+ out : (B, D, L) bfloat16
148
+ """
149
+ ops.causal_conv1d_fwd(x, weight, bias, out)
150
+
151
+
152
+ def causal_conv1d_update(
153
+ conv_state: torch.Tensor,
154
+ x: torch.Tensor,
155
+ weight: torch.Tensor,
156
+ bias: Optional[torch.Tensor],
157
+ out: torch.Tensor,
158
+ ) -> None:
159
+ """Single-step causal Conv1d update (single-token decode path).
160
+
161
+ conv_state : (B, D, d_conv) float32, in-place updated (rolled left, last slot = x)
162
+ x : (B, D) bfloat16 (new input)
163
+ weight : (D, d_conv) bfloat16
164
+ bias : (D,) float32 or None
165
+ out : (B, D) bfloat16
166
+ """
167
+ ops.causal_conv1d_update(conv_state, x, weight, bias, out)
build/torch212-cxx11-cu132-aarch64-linux/_gdn_cuda_bcba8ad.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32e0475d1a711df63bf51e0cbfa413e10989afde5e11d43f80855972bf48c91c
3
+ size 3373720
build/torch212-cxx11-cu132-aarch64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _gdn_cuda_bcba8ad
3
+ ops = torch.ops._gdn_cuda_bcba8ad
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_gdn_cuda_bcba8ad::{op_name}"
build/torch212-cxx11-cu132-aarch64-linux/gdn/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch212-cxx11-cu132-aarch64-linux/metadata.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "gdn",
3
+ "id": "_gdn_cuda_bcba8ad",
4
+ "version": 0,
5
+ "license": "AGPL-3.0-only",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda"
9
+ }
10
+ }