Tags: Text Generation · Transformers · Safetensors · Chinese · neuronspark · snn · spiking-neural-network · neuromorphic · chat · conversational · custom_code
Instructions to use Brain2nd/NeuronSpark-0.9B-Chat with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Brain2nd/NeuronSpark-0.9B-Chat with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="Brain2nd/NeuronSpark-0.9B-Chat", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)
```

```python
# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Brain2nd/NeuronSpark-0.9B-Chat", trust_remote_code=True, dtype="auto")
```

- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use Brain2nd/NeuronSpark-0.9B-Chat with vLLM:
Install from pip and serve the model:

```shell
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "Brain2nd/NeuronSpark-0.9B-Chat"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "Brain2nd/NeuronSpark-0.9B-Chat",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
```

Use Docker:

```shell
docker model run hf.co/Brain2nd/NeuronSpark-0.9B-Chat
```
- SGLang
How to use Brain2nd/NeuronSpark-0.9B-Chat with SGLang:
Install from pip and serve the model:

```shell
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "Brain2nd/NeuronSpark-0.9B-Chat" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "Brain2nd/NeuronSpark-0.9B-Chat",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
```

Use Docker images:

```shell
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "Brain2nd/NeuronSpark-0.9B-Chat" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "Brain2nd/NeuronSpark-0.9B-Chat",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
```

- Docker Model Runner
How to use Brain2nd/NeuronSpark-0.9B-Chat with Docker Model Runner:
```shell
docker model run hf.co/Brain2nd/NeuronSpark-0.9B-Chat
```
| """ | |
| Parallel Scan 工具函数:SNN 线性递推的高效并行求解 | |
| 实现三层后端: | |
| 1. Fused PLIF kernel(默认,CUDA + Sigmoid surrogate): | |
| 单 kernel 完成 PLIF 前向(scan + spike + soft reset)和反向(surrogate gradient) | |
| · per-element beta/v_th: _fused_plif_fwd_kernel / _fused_plif_bwd_kernel | |
| · row-param beta/v_th: _fused_plif_fwd_rowparam_kernel / _fused_plif_bwd_rowparam_kernel | |
| 2. Triton linear_recurrence(CUDA,非 Sigmoid 或无 surrogate): | |
| 列级并行 scan,O(K) 工作量,1 次 kernel launch | |
| 3. Hillis-Steele parallel scan(CPU 回退):O(K log K) 工作量 | |
| 线性递推: | |
| V[k] = a[k] * V[k-1] + b[k], V[-1] = v_init | |
| PLIF 神经元动力学: | |
| V_pre[k] = beta[k] * V_post[k-1] + u[k] | |
| s[k] = Θ(V_pre[k] - v_th[k]) | |
| V_post[k] = V_pre[k] - v_th[k] * s[k] | |
| 数学原理见 SNN_SELECTIVE_STATE_SPACE.md。 | |
| """ | |
import os

import torch

# ============================================================
# Triton fused recurrence kernels
# ============================================================

# DGX Spark (GB10, sm_121a): the ptxas bundled with Triton 3.5.1 does not
# support sm_121a; use the system CUDA 13.0 ptxas instead.
_SYSTEM_PTXAS = '/usr/local/cuda-13.0/bin/ptxas'
if os.path.exists(_SYSTEM_PTXAS) and 'TRITON_PTXAS_PATH' not in os.environ:
    os.environ['TRITON_PTXAS_PATH'] = _SYSTEM_PTXAS

_HAS_TRITON = False
try:
    import triton
    import triton.language as tl
    _HAS_TRITON = True
except ImportError:
    pass
if _HAS_TRITON:
    @triton.jit
    def _fwd_recurrence_kernel(
        A_ptr, B_ptr, INIT_ptr, OUT_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Forward: V[k] = A[k]*V[k-1] + B[k], V[-1] = init.

        Grid: (ceil(num_cols / BLOCK),)
        Each program processes BLOCK columns across all K sequential steps.
        Accumulation in float32; storage in input dtype.
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols
        v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        for k in range(K):
            off = k * num_cols + cols
            a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
            b = tl.load(B_ptr + off, mask=mask, other=0.0).to(tl.float32)
            v = a * v + b
            tl.store(OUT_ptr + off, v, mask=mask)

    @triton.jit
    def _bwd_recurrence_kernel(
        A_ptr, V_ptr, INIT_ptr, GRAD_OUT_ptr,
        GRAD_A_ptr, GRAD_B_ptr, GRAD_INIT_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Backward for V[k] = A[k]*V[k-1] + B[k].

        Reverse accumulation (k from K-1 down to 0):
            g = 0
            for k = K-1, ..., 0:
                g += grad_out[k]
                grad_B[k] = g
                grad_A[k] = g * V[k-1]   (V[-1] = init)
                g = g * A[k]
            grad_init = g
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols
        g = tl.zeros([BLOCK], dtype=tl.float32)
        for k_rev in range(K):
            k = K - 1 - k_rev
            off = k * num_cols + cols
            dV = tl.load(GRAD_OUT_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g = g + dV
            tl.store(GRAD_B_ptr + off, g, mask=mask)
            if k > 0:
                v_prev = tl.load(
                    V_ptr + (k - 1) * num_cols + cols,
                    mask=mask, other=0.0,
                ).to(tl.float32)
            else:
                v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
            tl.store(GRAD_A_ptr + off, g * v_prev, mask=mask)
            a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g = g * a
        tl.store(GRAD_INIT_ptr + cols, g, mask=mask)
    class _TritonLinearRecurrence(torch.autograd.Function):
        """Fused Triton linear recurrence: V[k] = A[k]*V[k-1] + B[k]."""

        _BLOCK = 128

        @staticmethod
        def forward(ctx, beta, u, v_init):
            beta_c = beta.contiguous()
            u_c = u.contiguous()
            v_init_c = v_init.contiguous()
            K = beta_c.shape[0]
            num_cols = beta_c[0].numel()
            V = torch.empty_like(u_c)
            BLOCK = _TritonLinearRecurrence._BLOCK
            grid = ((num_cols + BLOCK - 1) // BLOCK,)
            _fwd_recurrence_kernel[grid](
                beta_c, u_c, v_init_c, V,
                K, num_cols,
                BLOCK=BLOCK,
            )
            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
                ctx.save_for_backward(beta_c, V, v_init_c)
                ctx.K = K
                ctx.num_cols = num_cols
            return V

        @staticmethod
        def backward(ctx, grad_V):
            beta, V, v_init = ctx.saved_tensors
            grad_V_c = grad_V.contiguous()
            K = ctx.K
            num_cols = ctx.num_cols
            grad_beta = torch.empty_like(beta)
            grad_u = torch.empty_like(beta)
            grad_v_init = torch.empty_like(v_init)
            BLOCK = _TritonLinearRecurrence._BLOCK
            grid = ((num_cols + BLOCK - 1) // BLOCK,)
            _bwd_recurrence_kernel[grid](
                beta, V, v_init, grad_V_c,
                grad_beta, grad_u, grad_v_init,
                K, num_cols,
                BLOCK=BLOCK,
            )
            return grad_beta, grad_u, grad_v_init
    # ============================================================
    # Fused PLIF forward/backward kernels
    # ============================================================

    @triton.jit
    def _fused_plif_fwd_kernel(
        BETA_ptr, U_ptr, VTH_ptr, INIT_ptr,
        SPIKE_ptr, VPOST_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Fused PLIF forward: single-pass sequential scan with inline spike + soft reset.

        Exact computation — the sequential scan IS the ground truth.
        Replaces the 3-phase approach (linear scan + spike iteration + correction).

        Per column (parallel across batch*D):
            v = v_init
            for k = 0..K-1:
                v_pre = beta[k]*v + u[k]
                spike[k] = Θ(v_pre - v_th[k])
                v = v_pre - v_th[k]*spike[k]
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols
        v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        for k in range(K):
            off = k * num_cols + cols
            beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
            u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)
            vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)
            v_pre = beta * v + u
            spike = tl.where(v_pre >= vth, 1.0, 0.0)
            v = v_pre - vth * spike  # soft reset
            tl.store(SPIKE_ptr + off, spike, mask=mask)
            tl.store(VPOST_ptr + off, v, mask=mask)

    @triton.jit
    def _fused_plif_bwd_kernel(
        BETA_ptr, VTH_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
        GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
        GRAD_BETA_ptr, GRAD_U_ptr, GRAD_VTH_ptr, GRAD_INIT_ptr,
        K, num_cols, ALPHA,
        BLOCK: tl.constexpr,
    ):
        """Fused PLIF backward: single reverse pass with Sigmoid surrogate gradient.

            V_pre[k] = V_post[k] + v_th[k]*spike[k]   (reconstructed)
            surrogate_grad(x) = alpha * sigmoid(alpha*x) * (1 - sigmoid(alpha*x))
            where x = V_pre[k] - v_th[k] = V_post[k] - v_th[k]*(1 - spike[k])

        Reverse accumulation:
            acc = 0
            for k = K-1 downto 0:
                total_gV = grad_V_post[k] + acc
                sg = surrogate_grad(V_pre[k] - v_th[k])
                grad_v_pre = grad_spike[k]*sg + total_gV
                grad_beta[k] = grad_v_pre * V_post[k-1]
                grad_u[k] = grad_v_pre
                grad_v_th[k] = -grad_spike[k]*sg - total_gV*spike[k]
                acc = grad_v_pre * beta[k]
            grad_v_init = acc
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols
        acc = tl.zeros([BLOCK], dtype=tl.float32)
        for k_rev in range(K):
            k = K - 1 - k_rev
            off = k * num_cols + cols
            beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
            vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)
            v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
            spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
            # V_post[k-1]
            if k > 0:
                v_prev = tl.load(
                    VPOST_ptr + (k - 1) * num_cols + cols,
                    mask=mask, other=0.0,
                ).to(tl.float32)
            else:
                v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
            # Sigmoid surrogate gradient
            x = v_post - vth * (1.0 - spike)  # = V_pre - v_th
            neg_ax = -ALPHA * x
            neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax)  # prevent exp overflow
            sig = 1.0 / (1.0 + tl.exp(neg_ax))
            sg = ALPHA * sig * (1.0 - sig)
            total_gV = g_V + acc
            grad_v_pre = g_s * sg + total_gV
            tl.store(GRAD_BETA_ptr + off, grad_v_pre * v_prev, mask=mask)
            tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)
            tl.store(GRAD_VTH_ptr + off, -g_s * sg - total_gV * spike, mask=mask)
            acc = grad_v_pre * beta
        tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
    # ============================================================
    # Fused PLIF kernels with row-parameter beta/v_th
    # (constant across K steps — e.g., ParametricLIFNode scalars)
    # ============================================================

    @triton.jit
    def _fused_plif_fwd_rowparam_kernel(
        BETA_ROW_ptr, U_ptr, VTH_ROW_ptr, INIT_ptr,
        SPIKE_ptr, VPOST_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Fused PLIF forward with row-parameter beta and v_th.

        beta and v_th are (*shape) — constant across K steps, loaded once into
        registers. Reduces global memory reads from 3 per step (beta, u, v_th)
        to 1 (u only).
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols
        v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        for k in range(K):
            off = k * num_cols + cols
            u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)
            v_pre = beta * v + u
            spike = tl.where(v_pre >= vth, 1.0, 0.0)
            v = v_pre - vth * spike
            tl.store(SPIKE_ptr + off, spike, mask=mask)
            tl.store(VPOST_ptr + off, v, mask=mask)

    @triton.jit
    def _fused_plif_bwd_rowparam_kernel(
        BETA_ROW_ptr, VTH_ROW_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
        GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
        GRAD_BETA_ROW_ptr, GRAD_U_ptr, GRAD_VTH_ROW_ptr, GRAD_INIT_ptr,
        K, num_cols, ALPHA,
        BLOCK: tl.constexpr,
    ):
        """Fused PLIF backward with row-parameter beta/v_th.

        Gradients for beta and v_th are accumulated over the K steps (reduction
        in registers). Returns grad_beta_row (*shape) and grad_v_th_row (*shape)
        instead of per-step gradients.
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols
        beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        acc = tl.zeros([BLOCK], dtype=tl.float32)
        acc_grad_beta = tl.zeros([BLOCK], dtype=tl.float32)
        acc_grad_vth = tl.zeros([BLOCK], dtype=tl.float32)
        for k_rev in range(K):
            k = K - 1 - k_rev
            off = k * num_cols + cols
            v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
            spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
            if k > 0:
                v_prev = tl.load(
                    VPOST_ptr + (k - 1) * num_cols + cols,
                    mask=mask, other=0.0,
                ).to(tl.float32)
            else:
                v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
            # Sigmoid surrogate gradient
            x = v_post - vth * (1.0 - spike)
            neg_ax = -ALPHA * x
            neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax)
            sig = 1.0 / (1.0 + tl.exp(neg_ax))
            sg = ALPHA * sig * (1.0 - sig)
            total_gV = g_V + acc
            grad_v_pre = g_s * sg + total_gV
            tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)
            # Accumulate gradients for the row parameters (reduction over K in registers)
            acc_grad_beta += grad_v_pre * v_prev
            acc_grad_vth += -g_s * sg - total_gV * spike
            acc = grad_v_pre * beta
        tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
        tl.store(GRAD_BETA_ROW_ptr + cols, acc_grad_beta, mask=mask)
        tl.store(GRAD_VTH_ROW_ptr + cols, acc_grad_vth, mask=mask)
    class _TritonPLIFRowParamForward(torch.autograd.Function):
        """Fused Triton PLIF with row-parameter beta/v_th.

        For neurons whose beta/v_th are constant across the K steps
        (ParametricLIFNode). Eliminates expand+contiguous for the beta/v_th
        tensors and reduces memory I/O by ~40%.
        """

        _BLOCK = 128

        @staticmethod
        def forward(ctx, beta_row, u, v_th_row, v_init, alpha):
            beta_row_c = beta_row.contiguous()
            u_c = u.contiguous()
            v_th_row_c = v_th_row.contiguous()
            v_init_c = v_init.contiguous()
            K = u_c.shape[0]
            num_cols = u_c[0].numel()
            spike = torch.empty_like(u_c)
            V_post = torch.empty_like(u_c)
            BLOCK = _TritonPLIFRowParamForward._BLOCK
            grid = ((num_cols + BLOCK - 1) // BLOCK,)
            _fused_plif_fwd_rowparam_kernel[grid](
                beta_row_c, u_c, v_th_row_c, v_init_c,
                spike, V_post,
                K, num_cols,
                BLOCK=BLOCK,
            )
            if any(ctx.needs_input_grad[:4]):
                ctx.save_for_backward(beta_row_c, v_th_row_c, v_init_c, V_post, spike)
                ctx.K = K
                ctx.num_cols = num_cols
                ctx.alpha = alpha
            return spike, V_post

        @staticmethod
        def backward(ctx, grad_spike, grad_V_post):
            beta_row, v_th_row, v_init, V_post, spike = ctx.saved_tensors
            K = ctx.K
            num_cols = ctx.num_cols
            alpha = ctx.alpha
            if grad_spike is None:
                grad_spike = torch.zeros_like(spike)
            if grad_V_post is None:
                grad_V_post = torch.zeros_like(V_post)
            grad_spike_c = grad_spike.contiguous()
            grad_V_post_c = grad_V_post.contiguous()
            grad_beta_row = torch.empty_like(beta_row)
            grad_u = torch.empty_like(V_post)
            grad_v_th_row = torch.empty_like(v_th_row)
            grad_v_init = torch.empty_like(v_init)
            BLOCK = _TritonPLIFRowParamForward._BLOCK
            grid = ((num_cols + BLOCK - 1) // BLOCK,)
            _fused_plif_bwd_rowparam_kernel[grid](
                beta_row, v_th_row, v_init, V_post, spike,
                grad_spike_c, grad_V_post_c,
                grad_beta_row, grad_u, grad_v_th_row, grad_v_init,
                K, num_cols, float(alpha),
                BLOCK=BLOCK,
            )
            return grad_beta_row, grad_u, grad_v_th_row, grad_v_init, None
    class _TritonPLIFForward(torch.autograd.Function):
        """Fused Triton PLIF forward + backward.

        A single-pass sequential scan replaces the 3-phase approach:
            Phase 1 (linear scan) + Phase 2 (spike iteration) + Phase 3 (correction)
            → 1 fused kernel with inline spike detection + soft reset

        Advantages:
        - 1 kernel launch (vs 3-4 launches + ~10 element-wise ops)
        - Exact computation (no iteration convergence issues)
        - Less memory (no intermediate V_L, delta_S, delta_S_prev)
        - Higher precision (fp32 accumulation, no bf16 intermediate store/load)
        """

        _BLOCK = 128

        @staticmethod
        def forward(ctx, beta, u, v_th, v_init, alpha):
            beta_c = beta.contiguous()
            u_c = u.contiguous()
            v_th_c = v_th.contiguous()
            v_init_c = v_init.contiguous()
            K = beta_c.shape[0]
            num_cols = beta_c[0].numel()
            spike = torch.empty_like(u_c)
            V_post = torch.empty_like(u_c)
            BLOCK = _TritonPLIFForward._BLOCK
            grid = ((num_cols + BLOCK - 1) // BLOCK,)
            _fused_plif_fwd_kernel[grid](
                beta_c, u_c, v_th_c, v_init_c,
                spike, V_post,
                K, num_cols,
                BLOCK=BLOCK,
            )
            if any(ctx.needs_input_grad[:4]):
                ctx.save_for_backward(beta_c, v_th_c, v_init_c, V_post, spike)
                ctx.K = K
                ctx.num_cols = num_cols
                ctx.alpha = alpha
            return spike, V_post

        @staticmethod
        def backward(ctx, grad_spike, grad_V_post):
            beta, v_th, v_init, V_post, spike = ctx.saved_tensors
            K = ctx.K
            num_cols = ctx.num_cols
            alpha = ctx.alpha
            if grad_spike is None:
                grad_spike = torch.zeros_like(spike)
            if grad_V_post is None:
                grad_V_post = torch.zeros_like(V_post)
            grad_spike_c = grad_spike.contiguous()
            grad_V_post_c = grad_V_post.contiguous()
            grad_beta = torch.empty_like(beta)
            grad_u = torch.empty_like(beta)
            grad_v_th = torch.empty_like(v_th)
            grad_v_init = torch.empty_like(v_init)
            BLOCK = _TritonPLIFForward._BLOCK
            grid = ((num_cols + BLOCK - 1) // BLOCK,)
            _fused_plif_bwd_kernel[grid](
                beta, v_th, v_init, V_post, spike,
                grad_spike_c, grad_V_post_c,
                grad_beta, grad_u, grad_v_th, grad_v_init,
                K, num_cols, float(alpha),
                BLOCK=BLOCK,
            )
            return grad_beta, grad_u, grad_v_th, grad_v_init, None
# ============================================================
# Hillis-Steele parallel prefix scan (CPU fallback)
# ============================================================

def hillis_steele_scan(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Hillis-Steele parallel prefix scan: compute all prefix compositions of a
    sequence of affine maps.

    Given affine maps f_k(x) = a[k] * x + b[k], k = 0, ..., K-1, compute the
    prefix compositions F_k = f_k ∘ f_{k-1} ∘ ... ∘ f_0, so that
    V[k] = F_k(v_init) = A[k] * v_init + B[k].

    Composition rule: (a2, b2) ∘ (a1, b1) = (a2 * a1, a2 * b1 + b2)

    The implementation rebuilds tensors with torch.cat (no in-place ops) and is
    therefore fully autograd-compatible.

    Args:
        a: (K, *shape) — multiplicative coefficients (e.g. β)
        b: (K, *shape) — additive terms (e.g. α·I)

    Returns:
        A: (K, *shape) — prefix products A[k] = ∏_{j=0}^{k} a[j]
        B: (K, *shape) — prefix terms B[k] such that V[k] = A[k] * v_init + B[k]

    Parallel depth: O(log K)
    Work:           O(K log K)
    """
    K = a.shape[0]
    A = a
    B = b
    d = 1
    while d < K:
        A_new_tail = A[d:] * A[:-d]
        B_new_tail = A[d:] * B[:-d] + B[d:]
        A = torch.cat([A[:d], A_new_tail], dim=0)
        B = torch.cat([B[:d], B_new_tail], dim=0)
        d *= 2
    return A, B
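# Illustrative sketch (not used by the library code): the same Hillis-Steele
# recursion over affine pairs, written with plain Python lists instead of
# tensors. At stride d, element k absorbs element k-d via the composition rule
# (a2, b2) ∘ (a1, b1) = (a2*a1, a2*b1 + b2), mirroring the function above.
def _hillis_steele_affine(a, b):
    """Return prefix-composed (A, B) with V[k] = A[k] * v_init + B[k]."""
    A, B = list(a), list(b)
    d = 1
    while d < len(A):
        # operate on a snapshot so updates at this stride don't interfere
        A_old, B_old = A[:], B[:]
        for k in range(d, len(A)):
            A[k] = A_old[k] * A_old[k - d]
            B[k] = A_old[k] * B_old[k - d] + B_old[k]
        d *= 2
    return A, B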
# ============================================================
# Public API: linear_recurrence
# ============================================================

def linear_recurrence(beta: torch.Tensor, u: torch.Tensor, v_init: torch.Tensor) -> torch.Tensor:
    """
    Solve the linear recurrence V[k] = beta[k] * V[k-1] + u[k], V[-1] = v_init.

    CUDA backend: Triton fused kernel (1 kernel launch, O(K) work)
    CPU backend:  Hillis-Steele parallel scan (O(K log K) work)

    Args:
        beta:   (K, *shape) — decay coefficients, in (0, 1)
        u:      (K, *shape) — input terms
        v_init: (*shape)    — initial state

    Returns:
        V: (K, *shape) — states at all K steps
    """
    if _HAS_TRITON and beta.is_cuda:
        return _TritonLinearRecurrence.apply(beta, u, v_init)
    # CPU fallback
    A, B = hillis_steele_scan(beta, u)
    V = A * v_init.unsqueeze(0) + B
    return V
# ============================================================
# PLIF parallel forward (with spike iteration)
# ============================================================

def plif_parallel_forward(
    beta: torch.Tensor,
    u: torch.Tensor,
    v_th: torch.Tensor,
    v_init: torch.Tensor,
    max_iter: int = 3,
    surrogate_function=None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Parallel forward pass for PLIF neurons (soft reset, surrogate-gradient compatible).

    Solves:
        V_pre[k]  = beta[k] * V_post[k-1] + u[k]
        s[k]      = Θ(V_pre[k] - v_th[k])
        V_post[k] = V_pre[k] - v_th[k] * s[k]

    Method:
        Phase 1: parallel scan of the linear trajectory (with gradients)
        Phase 2: spike fixed-point iteration (detached; determines the discrete spike pattern)
        Phase 3: recompute V_post with the converged spike pattern (with gradients);
                 surrogate_function(V_pre - v_th) produces the differentiable spike output

    Args:
        beta:   (K, *shape) — decay coefficients
        u:      (K, *shape) — input α·I
        v_th:   (K, *shape) — dynamic thresholds
        v_init: (*shape)    — initial membrane potential
        max_iter: maximum number of spike fixed-point iterations
        surrogate_function: SpikingJelly surrogate gradient function
            (e.g. surrogate.Sigmoid(alpha=4.0)); None degrades to a hard
            threshold with no gradient

    Returns:
        spike:  (K, *shape) — spike pattern (with surrogate gradient)
        V_post: (K, *shape) — post-spike membrane potential
        V_pre:  (K, *shape) — pre-spike membrane potential (None on the fused path)
    """
    # Fused Triton path: single-pass sequential scan (exact, no iteration).
    # Replaces the 3-phase approach with 1 kernel launch — ~3x faster forward,
    # ~5x faster backward.
    if (_HAS_TRITON and beta.is_cuda and surrogate_function is not None
            and hasattr(surrogate_function, 'alpha')
            and type(surrogate_function).__name__ == 'Sigmoid'):
        alpha = float(surrogate_function.alpha)
        spike, V_post = _TritonPLIFForward.apply(beta, u, v_th, v_init, alpha)
        return spike, V_post, None

    # Fallback: 3-phase approach (CPU, non-Sigmoid surrogates, or no surrogate)

    # Phase 1: linear trajectory V_L (assuming the neuron never fires)
    V_L = linear_recurrence(beta, u, v_init)  # (K, *shape)

    # Phase 2: spike fixed-point iteration (fully detached, no gradient graph).
    # Purpose: decide which neurons fire at which steps (discrete decision).
    with torch.no_grad():
        V_L_det = V_L.detach()
        beta_det = beta.detach()
        v_th_det = v_th.detach()
        v_init_det = v_init.detach() if isinstance(v_init, torch.Tensor) else v_init
        spike_pattern = (V_L_det >= v_th_det).float()
        for _ in range(max_iter - 1):
            # Compute ΔS: ΔS[k] = beta[k] * ΔS[k-1] + v_th[k] * s[k]
            delta_S = linear_recurrence(
                beta_det, v_th_det * spike_pattern,
                torch.zeros_like(v_init_det) if isinstance(v_init_det, torch.Tensor)
                else torch.zeros_like(V_L_det[0]),
            )
            # ΔS_prev = ΔS[k-1] (shifted by one step)
            delta_S_prev = torch.zeros_like(delta_S)
            delta_S_prev[1:] = delta_S[:-1]
            # V_pre = V_L - beta * ΔS_prev
            V_pre_det = V_L_det - beta_det * delta_S_prev
            # Update spikes
            spike_new = (V_pre_det >= v_th_det).float()
            # Convergence check
            if torch.equal(spike_new, spike_pattern):
                break
            spike_pattern = spike_new

    # Phase 3: recompute V_post with the converged spike pattern (full gradients).
    # spike_pattern is detached and enters the computation as a constant;
    # gradients flow through u, v_th, beta, and v_init.
    u_eff = u - v_th * spike_pattern
    V_post = linear_recurrence(beta, u_eff, v_init)  # (K, *shape)

    # Reconstruct V_pre (with gradients, for the surrogate gradient)
    V_post_prev = torch.zeros_like(V_post)
    if isinstance(v_init, torch.Tensor):
        V_post_prev[0] = v_init
    V_post_prev[1:] = V_post[:-1]
    V_pre = beta * V_post_prev + u

    # Produce the differentiable spike output
    if surrogate_function is not None:
        # forward: Heaviside(V_pre - v_th); backward: surrogate gradient
        spike = surrogate_function(V_pre - v_th)
    else:
        # Degraded mode: hard threshold, no gradient
        spike = (V_pre >= v_th).float()

    return spike, V_post, V_pre
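# Illustrative reference (not used by the library code): the sequential loop
# that every PLIF path above must reproduce — hard threshold Θ followed by a
# soft reset that subtracts the threshold rather than clearing the membrane.
def _reference_plif(beta, u, v_th, v_init):
    """Step-by-step PLIF dynamics over plain Python floats."""
    spikes, v_posts, v = [], [], v_init
    for bk, uk, tk in zip(beta, u, v_th):
        v_pre = bk * v + uk                # leaky integration
        s = 1.0 if v_pre >= tk else 0.0    # Θ(V_pre - v_th)
        v = v_pre - tk * s                 # soft reset
        spikes.append(s)
        v_posts.append(v)
    return spikes, v_posts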
def plif_rowparam_forward(
    beta_row: torch.Tensor,
    u: torch.Tensor,
    v_th_row: torch.Tensor,
    v_init: torch.Tensor,
    surrogate_function=None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Row-parameter PLIF forward: beta and v_th stay constant across the K steps.

    ~40% faster than plif_parallel_forward (skips expand+contiguous and cuts
    global-memory reads by 2/3). Intended for ParametricLIFNode (fixed
    beta/v_th) or for batching several fixed-parameter neurons.

    Args:
        beta_row: (*shape)    — per-column decay rate (same for all K steps)
        u:        (K, *shape) — per-step input
        v_th_row: (*shape)    — per-column threshold (same for all K steps)
        v_init:   (*shape)    — initial membrane potential
        surrogate_function: surrogate gradient function

    Returns:
        spike:  (K, *shape) — spike pattern
        V_post: (K, *shape) — post-spike membrane potential
    """
    if (_HAS_TRITON and u.is_cuda and surrogate_function is not None
            and hasattr(surrogate_function, 'alpha')
            and type(surrogate_function).__name__ == 'Sigmoid'):
        alpha = float(surrogate_function.alpha)
        spike, V_post = _TritonPLIFRowParamForward.apply(
            beta_row, u, v_th_row, v_init, alpha,
        )
        return spike, V_post

    # Fallback: expand to full (K, *shape) and use the standard path
    K = u.shape[0]
    beta = beta_row.unsqueeze(0).expand(K, *u.shape[1:]).contiguous()
    v_th = v_th_row.unsqueeze(0).expand(K, *u.shape[1:]).contiguous()
    spike, V_post, _ = plif_parallel_forward(beta, u, v_th, v_init, surrogate_function=surrogate_function)
    return spike, V_post
def plif_fixed_param_forward(
    beta,
    u: torch.Tensor,
    v_th,
    v_init: torch.Tensor,
    max_iter: int = 3,
    surrogate_function=None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Parallel forward for fixed-parameter PLIF neurons (e.g. output neurons, FFN neurons).

    ParametricLIFNode equation: V[k] = beta * V[k-1] + (1-beta) * x[k]
    where beta = 1/(1+exp(w)) and may be a scalar tensor (keeping the gradient
    flowing back to w).

    Scalar/0-dim beta and v_th use the row-param kernels (no expand to (K, *shape)).

    Args:
        beta:   decay rate — scalar float, 0-dim tensor, or (K, *shape) tensor
        u:      (K, *shape) — input (already multiplied by (1-beta))
        v_th:   threshold — scalar float, 0-dim tensor, or (K, *shape) tensor
        v_init: (*shape) — initial membrane potential
        max_iter: number of spike iterations
        surrogate_function: surrogate gradient function

    Returns:
        spike:  (K, *shape) — spike pattern
        V_post: (K, *shape) — post-spike membrane potential
    """
    K = u.shape[0]
    shape = u.shape[1:]

    # Row-param fast path: beta and v_th are both scalar/0-dim → expand them to
    # (*shape) row vectors
    beta_is_scalar = isinstance(beta, torch.Tensor) and beta.dim() == 0
    beta_is_float = not isinstance(beta, torch.Tensor)
    vth_is_scalar = isinstance(v_th, torch.Tensor) and v_th.dim() == 0
    vth_is_float = not isinstance(v_th, torch.Tensor)
    if (beta_is_scalar or beta_is_float) and (vth_is_scalar or vth_is_float):
        # Build row vectors (*shape)
        if beta_is_scalar:
            beta_row = beta.expand(*shape).contiguous()
        else:
            beta_row = torch.full(shape, beta, device=u.device, dtype=u.dtype)
        if vth_is_scalar:
            v_th_row = v_th.expand(*shape).contiguous()
        else:
            v_th_row = torch.full(shape, v_th, device=u.device, dtype=u.dtype)
        return plif_rowparam_forward(beta_row, u, v_th_row, v_init, surrogate_function)

    # Full-tensor path: expand to (K, *shape) if needed
    if isinstance(beta, torch.Tensor):
        if beta.dim() == 0:
            beta = beta.expand(K, *shape).contiguous()
    else:
        beta = torch.full_like(u, beta)
    if isinstance(v_th, torch.Tensor):
        if v_th.dim() == 0:
            v_th = v_th.expand(K, *shape).contiguous()
    else:
        v_th = torch.full_like(u, v_th)
    spike, V_post, _ = plif_parallel_forward(
        beta, u, v_th, v_init, max_iter, surrogate_function,
    )
    return spike, V_post
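# Illustrative sketch of the ParametricLIFNode parameterization noted in the
# docstring above (plain-float helpers, for illustration only): the learnable
# weight w maps to the decay rate via beta = 1/(1+exp(w)), and one membrane
# update follows V[k] = beta * V[k-1] + (1-beta) * x[k].
import math

def _plif_beta_from_w(w):
    """Decay rate beta = 1 / (1 + exp(w))."""
    return 1.0 / (1.0 + math.exp(w))

def _plif_fixed_step(v_prev, x, w):
    """One fixed-parameter membrane update V[k] = beta * V[k-1] + (1-beta) * x[k]."""
    beta = _plif_beta_from_w(w)
    return beta * v_prev + (1.0 - beta) * x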