// Tier D — multi-layer convolutional EqProp trainer (end-to-end, not greedy).
// Architecture: N conv layers followed by a dense readout (a single layer, or a chain when
// `denseSizes` is passed). All algorithmic — N is determined by the length of the convCfgs
// array passed to the constructor.
//
// Critical design:
//   * Each conv layer has its own W, b, U-state-per-phase buffer set.
//   * Each conv layer receives top-down feedback from the next layer:
//       - For NON-last conv (layer l < N-1): a transposed conv through layer l+1's kernel
//         (conv_pass has_topdown_type = 2).
//       - For the LAST conv (layer N-1): the first dense layer's weights, read as
//         [O × this_flat] (has_topdown_type = 1).
//   * When the next layer's stride > 1 the conv-transpose covers strictly fewer positions per
//     kernel offset (some (iy, ix) have no preimage in (yo, xo)) — this is handled by the
//     `if(integer && in_range)`-style divisibility and range checks in conv_pass.

import { orth as orthCPU } from './eqprop_lib.js';

// Phase indices into the per-phase buffer arrays (e.g. bufUout[phase]).
const PHASE_F = 0, PHASE_P = 1, PHASE_M = 2;

// WGSL: relaxation step for one conv layer (entry point `conv_pass`) plus state init (`init_state`).
//   Deterministic mode: u[b, k, y, x] <- u_old + Tau[k] * (-u_old + ρ(c))
//   Spike mode:         u <- running mean over relax iterations of s ~ Bernoulli(ρ(c))
//   where c = bias + Σ kernel·ρ(input)
//             + gamma      * top-down   (if has_topdown_type != 0)
//             + skip_gamma * dense skip (if has_skip != 0),
//   optionally clamped to [clamp_lo, clamp_hi] when clamp_hi > clamp_lo.
//   has_topdown_type: 0=none, 1=dense-next (Wnxt is [O × this_flat]),
//                     2=conv-next (Wnxt is conv kernel [Cnxt × Cthis × KHnxt × KWnxt]).
//   Bindings must match bglConv in the constructor: 0 uniform, 5 read_write storage,
//   all others read-only storage.
const WGSL_CONV_RELAX_MULTI = `
struct CP {
  B: u32, Cin: u32, Cout: u32, H: u32, W: u32, Hout: u32, Wout: u32,
  KH: u32, KW: u32, stride: u32, pad: u32, _p0: u32,
  dt: f32, beta_unused: f32, gamma: f32, mode: f32,
  has_topdown_type: u32, nxt_O: u32, nxt_KH: u32, nxt_KW: u32, nxt_stride: u32, nxt_pad: u32, nxt_Cnxt: u32, _p2: u32,
  clamp_lo: f32, clamp_hi: f32, triangle_offset: f32, triangle_power: f32,
  // MSMEN-MVT: stochastic spike-sampling mode (subset of Tempered Markov Energy Network)
  // spike_mode > 0: at each relax iter, sample s ~ Bernoulli(σ(c)) using iter_seed-derived PCG hash;
  // update u via running mean of spikes so the dense readout sees a fair estimate.
  // For inference-time M-sample ensembling, caller sets a different iter_seed_base per sample.
  spike_mode: u32, iter_index: u32, iter_seed_base: u32, _p3: u32,
  // SI-5: dense → conv skip connection. When has_skip=1, conv layer reads an
  // ADDITIONAL top-down from the LAST DENSE LAYER via a learnable W_skip[skip_O × this_flat].
  // Bypasses γ^L attenuation in deep stacks. skip_gamma controls its strength independently.
  has_skip: u32, skip_O: u32, skip_gamma: f32, _p4: u32,
};
@group(0) @binding(0) var<uniform> p : CP;
@group(0) @binding(1) var<storage, read> Xin : array<f32>;         // [B*Cin*H*W] input map
@group(0) @binding(2) var<storage, read> Wt : array<f32>;          // [Cout*Cin*KH*KW]
@group(0) @binding(3) var<storage, read> Bs : array<f32>;          // [Cout]
@group(0) @binding(4) var<storage, read> Wnxt : array<f32>;        // top-down weights (dense or conv kernel)
@group(0) @binding(5) var<storage, read_write> Uh : array<f32>;    // [B*Cout*Hout*Wout]
@group(0) @binding(6) var<storage, read> Unxt : array<f32>;        // [B*nxt_O] dense or [B*Cnxt*Hnxt*Wnxt] conv
@group(0) @binding(7) var<storage, read> Tau : array<f32>;         // [Cout] per-channel τ (HPSN); broadcast across spatial
@group(0) @binding(8) var<storage, read> Wskip : array<f32>;       // SI-5 [skip_O × this_flat] dense→conv skip W
@group(0) @binding(9) var<storage, read> Uskip : array<f32>;       // SI-5 [B × skip_O] last dense's state

// Activations supported (mode flag):
//   0 = adaptive σ (default, baseline)
//   1 = fhn clip        ρ(u) = clamp(u, 0, 1)
//   2 = prism softplus  smooth approximation of clip with bilateral gradient
//   3 = triangle Krotov ρ(u) = max(0, u - triangle_offset)^triangle_power
//       — offset is set externally (algorithmic; commonly the per-layer mean)
//       — power=1 gives RePU; power=2 gives RePU²
//
// Tau is per-output-channel time constant; replaces global p.dt. Constant Tau[k]=p.dt → identical
// to scalar-dt behavior (used for backward-compat default).
const PRISM_K : f32 = 10.0;

fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }

// Numerically stable log(1 + exp(x)): never exponentiates a large positive argument on the selected branch.
fn softplus_safe(x: f32) -> f32 {
  return select(x + log(1.0 + exp(-x)), log(1.0 + exp(x)), x <= 0.0);
}

fn prism_rho_c(u: f32) -> f32 {
  return (softplus_safe(PRISM_K * u) - softplus_safe(PRISM_K * (u - 1.0))) / PRISM_K;
}

fn triangle_rho_c(u: f32, off: f32, pwr: f32) -> f32 {
  let z = u - off;
  if (z <= 0.0) { return 0.0; }
  if (pwr == 1.0) { return z; }
  if (pwr == 2.0) { return z * z; }
  return pow(z, pwr);
}

fn rho(u: f32) -> f32 {
  // mode dispatch: 0 sigma, 1 clip, 2 prism, 3 triangle.
  // p.mode is uniform; all branches compile, one path runs per thread.
  if (p.mode > 2.5) { return triangle_rho_c(u, p.triangle_offset, p.triangle_power); }
  if (p.mode > 1.5) { return prism_rho_c(u); }
  if (p.mode > 0.5) { return clamp(u, 0.0, 1.0); }
  return sg(u);
}

// MSMEN-MVT: PCG-hash of a u32 seed. Deterministic for given seed.
fn pcg_u32(seed_in: u32) -> u32 {
  let state : u32 = seed_in * 747796405u + 2891336453u;
  let word : u32 = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
  return (word >> 22u) ^ word;
}

// Uniform sample in [0, 1). Deterministic for given (b, i, t, base).
fn pcg_unit(b: u32, i: u32, t: u32, base: u32) -> f32 {
  // Compose per-(batch, neuron, iter, seed_base) — independent samples across all axes.
  let s = b * 1000003u + i * 2654435761u + t * 374761393u + base * 2246822519u;
  // Keep only the top 24 bits: every 24-bit integer is exactly representable in f32, so the
  // result is strictly below 1.0. (f32(u32) / 2^32 rounds hashes near 2^32 up to exactly 1.0,
  // which would make "u < p_spike" fail even when p_spike == 1.0.)
  return f32(pcg_u32(s) >> 8u) / 16777216.0;
}

@compute @workgroup_size(8, 8, 1)
fn conv_pass(@builtin(global_invocation_id) gid: vec3<u32>) {
  let xo = gid.x;
  let yo = gid.y;
  let bk = gid.z;  // z packs (batch, output channel)
  if (xo >= p.Wout || yo >= p.Hout) { return; }
  let b = bk / p.Cout;
  let k = bk % p.Cout;
  if (b >= p.B) { return; }
  let img_size = p.Cin * p.H * p.W;
  let map_size = p.Cout * p.Hout * p.Wout;
  // Flat offset of (k, yo, xo) inside one sample's output map; shared by top-down, skip and state indexing.
  let this_flat = k * p.Hout * p.Wout + yo * p.Wout + xo;

  // Bottom-up
  var c : f32 = Bs[k];
  for (var kin: u32 = 0u; kin < p.Cin; kin = kin + 1u) {
    for (var dy: u32 = 0u; dy < p.KH; dy = dy + 1u) {
      let iy_s = i32(yo * p.stride + dy) - i32(p.pad);
      if (iy_s < 0 || iy_s >= i32(p.H)) { continue; }
      let iy = u32(iy_s);
      for (var dx: u32 = 0u; dx < p.KW; dx = dx + 1u) {
        let ix_s = i32(xo * p.stride + dx) - i32(p.pad);
        if (ix_s < 0 || ix_s >= i32(p.W)) { continue; }
        let ix = u32(ix_s);
        let xidx = b * img_size + kin * p.H * p.W + iy * p.W + ix;
        let widx = ((k * p.Cin + kin) * p.KH + dy) * p.KW + dx;
        c = c + Wt[widx] * rho(Xin[xidx]);
      }
    }
  }

  // Top-down
  if (p.has_topdown_type == 1u) {
    // Dense next: Wnxt is [O × map_size], Unxt is [B × O]
    var td : f32 = 0.0;
    for (var n: u32 = 0u; n < p.nxt_O; n = n + 1u) {
      td = td + Wnxt[n * map_size + this_flat] * rho(Unxt[b * p.nxt_O + n]);
    }
    c = c + p.gamma * td;
  } else if (p.has_topdown_type == 2u) {
    // Conv next: TRANSPOSED CONV (deconv).
    // Next layer output u[b, k_nxt, yo_nxt, xo_nxt] receives contribution from
    // THIS layer position (k, yo, xo) via kernel offset (dy_nxt, dx_nxt) when
    //   yo_nxt * nxt_stride + dy_nxt - nxt_pad == yo
    //   xo_nxt * nxt_stride + dx_nxt - nxt_pad == xo
    // So for this position, the top-down sum reads back ALL next-layer outputs that read FROM here.
    // Iterate kernel offsets; for each, compute the next-layer position that would have used this one.
    var td : f32 = 0.0;
    // Spatial shape of the next conv layer, from this layer's output shape:
    //   Hnxt = floor((Hout + 2*nxt_pad - nxt_KH)/nxt_stride) + 1
    let Hnxt : u32 = (p.Hout + 2u * p.nxt_pad - p.nxt_KH) / p.nxt_stride + 1u;
    let Wnxt_s : u32 = (p.Wout + 2u * p.nxt_pad - p.nxt_KW) / p.nxt_stride + 1u;
    let nxt_map_size = p.nxt_Cnxt * Hnxt * Wnxt_s;
    for (var k_nxt: u32 = 0u; k_nxt < p.nxt_Cnxt; k_nxt = k_nxt + 1u) {
      for (var dy_nxt: u32 = 0u; dy_nxt < p.nxt_KH; dy_nxt = dy_nxt + 1u) {
        // yo_nxt_s = (yo + nxt_pad - dy_nxt). Must be divisible by nxt_stride and in [0, Hnxt).
        let yo_nxt_s = i32(yo) + i32(p.nxt_pad) - i32(dy_nxt);
        if (yo_nxt_s < 0) { continue; }
        let yo_nxt_u = u32(yo_nxt_s);
        if (yo_nxt_u % p.nxt_stride != 0u) { continue; }
        let yo_nxt = yo_nxt_u / p.nxt_stride;
        if (yo_nxt >= Hnxt) { continue; }
        for (var dx_nxt: u32 = 0u; dx_nxt < p.nxt_KW; dx_nxt = dx_nxt + 1u) {
          let xo_nxt_s = i32(xo) + i32(p.nxt_pad) - i32(dx_nxt);
          if (xo_nxt_s < 0) { continue; }
          let xo_nxt_u = u32(xo_nxt_s);
          if (xo_nxt_u % p.nxt_stride != 0u) { continue; }
          let xo_nxt = xo_nxt_u / p.nxt_stride;
          if (xo_nxt >= Wnxt_s) { continue; }
          // Kernel weight: W[k_nxt, this_kin=k, dy_nxt, dx_nxt]
          let widx_nxt = ((k_nxt * p.Cout + k) * p.nxt_KH + dy_nxt) * p.nxt_KW + dx_nxt;
          let uidx_nxt = b * nxt_map_size + k_nxt * Hnxt * Wnxt_s + yo_nxt * Wnxt_s + xo_nxt;
          td = td + Wnxt[widx_nxt] * rho(Unxt[uidx_nxt]);
        }
      }
    }
    c = c + p.gamma * td;
  }

  // SI-5: dense → conv skip top-down (in ADDITION to existing chain top-down).
  if (p.has_skip != 0u) {
    var td_skip : f32 = 0.0;
    for (var n: u32 = 0u; n < p.skip_O; n = n + 1u) {
      td_skip = td_skip + Wskip[n * map_size + this_flat] * rho(Uskip[b * p.skip_O + n]);
    }
    c = c + p.skip_gamma * td_skip;
  }

  // Tier A — pre-σ drive clamp (active iff clamp_hi > clamp_lo)
  if (p.clamp_hi > p.clamp_lo) { c = clamp(c, p.clamp_lo, p.clamp_hi); }

  let idx = b * map_size + this_flat;
  let u_old = Uh[idx];
  let p_spike = rho(c);
  if (p.spike_mode != 0u) {
    // MSMEN-MVT: stochastic spike sampling. Running mean of binary spikes is the
    // unbiased estimator of σ(c) — matches deterministic in expectation, adds variance
    // per-iter that decorrelates samples (M-sample ensemble at inference).
    // n = iter_index + 1 (avoid /0 on first iter)
    let s_t = select(0.0, 1.0, pcg_unit(b, this_flat, p.iter_index, p.iter_seed_base) < p_spike);
    let n = f32(p.iter_index + 1u);
    Uh[idx] = (1.0 - 1.0 / n) * u_old + (1.0 / n) * s_t;
  } else {
    // Deterministic adaptive σ update — v07 default behavior.
    let drive = -u_old + p_spike;
    Uh[idx] = u_old + Tau[k] * drive;
  }
}

@compute @workgroup_size(64)
fn init_state(@builtin(global_invocation_id) gid: vec3<u32>) {
  // Flat index across a 2-D dispatch; assumes the host dispatches at most 65535 workgroups along x.
  let stride = 65535u * 64u;
  let g = gid.y * stride + gid.x;
  let n = p.B * p.Cout * p.Hout * p.Wout;
  if (g < n) { Uh[g] = 0.1; }
}
`;

// Dense layer relax — supports BOTH:
//   * Output dense (last in dense chain): has_target=1, gets +/-β nudge in plus/minus phase
//   * Hidden dense (Tier E — heterogeneous trainer): has_topdown=1, reads next-dense via Wnxt
// Wnxt layout: [Nnxt x No]; if has_topdown=0 the binding can be a dummy buffer.
// Bindings must match bglDense in the constructor: 0 uniform, 1-6 read-only storage, 7 read_write storage.
const WGSL_DENSE_OUT_MULTI = `
struct DP {
  B: u32, Ni: u32, No: u32, Nnxt: u32,
  dt: f32, beta: f32, gamma: f32, _p2: f32,
  has_target: u32, has_topdown: u32, _p4: u32, _p5: u32,
};
@group(0) @binding(0) var<uniform> p : DP;
@group(0) @binding(1) var<storage, read> Xin : array<f32>;        // [B*Ni]
@group(0) @binding(2) var<storage, read> Wt : array<f32>;         // [No*Ni]
@group(0) @binding(3) var<storage, read> Bs : array<f32>;         // [No]
@group(0) @binding(4) var<storage, read> Wnxt : array<f32>;       // [Nnxt*No] top-down weights (dummy if has_topdown=0)
@group(0) @binding(5) var<storage, read> Unxt : array<f32>;       // [B*Nnxt] next-layer state (dummy if has_topdown=0)
@group(0) @binding(6) var<storage, read> Tgt : array<f32>;        // [B*No]
@group(0) @binding(7) var<storage, read_write> Uo : array<f32>; // [B*No]

fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }

// One relax step for dense unit i of sample b:
//   u <- u_old + dt * (-u_old + σ(c) [+ beta * (target - u_old) when has_target and beta != 0])
//   c = bias + Σ_j W[i, j]·ρ(x_j) [+ gamma * Σ_k Wnxt[k, i]·ρ(u_nxt_k) when has_topdown]
@compute @workgroup_size(64, 1)
fn dense_pass(@builtin(global_invocation_id) gid: vec3<u32>) {
  let b = gid.y;
  let i = gid.x;
  if (b >= p.B || i >= p.No) { return; }
  var c : f32 = Bs[i];
  for (var j: u32 = 0u; j < p.Ni; j = j + 1u) {
    c = c + Wt[i * p.Ni + j] * rho(Xin[b * p.Ni + j]);
  }
  if (p.has_topdown != 0u) {
    var td : f32 = 0.0;
    for (var k: u32 = 0u; k < p.Nnxt; k = k + 1u) {
      td = td + Wnxt[k * p.No + i] * rho(Unxt[b * p.Nnxt + k]);
    }
    c = c + p.gamma * td;
  }
  let idx = b * p.No + i;
  let u_old = Uo[idx];
  var drive : f32 = -u_old + sg(c);
  if (p.has_target != 0u && p.beta != 0.0) {
    drive = drive + p.beta * (Tgt[idx] - u_old);
  }
  Uo[idx] = u_old + p.dt * drive;
}

@compute @workgroup_size(64)
fn init_state_out(@builtin(global_invocation_id) gid: vec3<u32>) {
  // Same flat indexing as the conv init_state: identical for a 1-D dispatch (gid.y == 0),
  // and lets a 2-D dispatch cover more than 65535 * 64 elements.
  let g = gid.y * (65535u * 64u) + gid.x;
  let n = p.B * p.No;
  if (g < n) { Uo[g] = 0.1; }
}
`;

// Gradient kernels per layer (conv & dense) — identical to single-conv lib.
// Contrastive (two-phase) gradient kernels. For each weight:
//   gW = Σ_b R[b] · (ρ(U+)·ρ(X+) − ρ(U−)·ρ(X−)) / two_beta
//   gB = Σ_b R[b] · (ρ(U+) − ρ(U−)) / two_beta   (conv: also summed over all output positions)
// Xp/Xm = pre-synaptic state (plus/minus phase), Up/Um = post-synaptic state, R = per-sample weight [B].
// NOTE(review): these kernels hard-code ρ = sigmoid (sg), while WGSL_CONV_RELAX_MULTI can relax with
// mode != 0 (clip / prism / triangle ρ); in those modes the gradient would not use the same ρ as the
// dynamics — confirm this is intended.
// Bindings must match bglGC / bglGD in the constructor: 0 uniform, 1-5 read-only storage,
// 6-7 read_write storage.
const WGSL_GRAD_CONV_MULTI = `
struct CGP {
  B: u32, Cin: u32, Cout: u32, H: u32, W: u32, Hout: u32, Wout: u32,
  KH: u32, KW: u32, stride: u32, pad: u32, _p0: u32,
  two_beta: f32, _p1: f32, _p2: f32, _p3: f32,
};
@group(0) @binding(0) var<uniform> p : CGP;
@group(0) @binding(1) var<storage, read> Xp : array<f32>;
@group(0) @binding(2) var<storage, read> Xm : array<f32>;
@group(0) @binding(3) var<storage, read> Up : array<f32>;
@group(0) @binding(4) var<storage, read> Um : array<f32>;
@group(0) @binding(5) var<storage, read> R : array<f32>;
@group(0) @binding(6) var<storage, read_write> gW : array<f32>;
@group(0) @binding(7) var<storage, read_write> gB : array<f32>;

fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }

@compute @workgroup_size(8, 8, 1)
fn grad_W_conv(@builtin(global_invocation_id) gid: vec3<u32>) {
  let dx = gid.x;
  let dy = gid.y;
  let kk = gid.z;  // z packs (output channel, input channel)
  if (dx >= p.KW || dy >= p.KH) { return; }
  let kout = kk / p.Cin;
  let kin = kk % p.Cin;
  if (kout >= p.Cout) { return; }
  let img_size = p.Cin * p.H * p.W;
  let map_size = p.Cout * p.Hout * p.Wout;
  var acc : f32 = 0.0;
  for (var b: u32 = 0u; b < p.B; b = b + 1u) {
    let rb = R[b];
    for (var yo: u32 = 0u; yo < p.Hout; yo = yo + 1u) {
      let iy_s = i32(yo * p.stride + dy) - i32(p.pad);
      if (iy_s < 0 || iy_s >= i32(p.H)) { continue; }
      let iy = u32(iy_s);
      for (var xo: u32 = 0u; xo < p.Wout; xo = xo + 1u) {
        let ix_s = i32(xo * p.stride + dx) - i32(p.pad);
        if (ix_s < 0 || ix_s >= i32(p.W)) { continue; }
        let ix = u32(ix_s);
        let u_flat = b * map_size + kout * p.Hout * p.Wout + yo * p.Wout + xo;
        let x_flat = b * img_size + kin * p.H * p.W + iy * p.W + ix;
        acc = acc + rb * (rho(Up[u_flat]) * rho(Xp[x_flat]) - rho(Um[u_flat]) * rho(Xm[x_flat]));
      }
    }
  }
  let widx = ((kout * p.Cin + kin) * p.KH + dy) * p.KW + dx;
  gW[widx] = acc / p.two_beta;
}

@compute @workgroup_size(64)
fn grad_B_conv(@builtin(global_invocation_id) gid: vec3<u32>) {
  let kout = gid.x;
  if (kout >= p.Cout) { return; }
  let map_size = p.Cout * p.Hout * p.Wout;
  var acc : f32 = 0.0;
  for (var b: u32 = 0u; b < p.B; b = b + 1u) {
    let rb = R[b];
    for (var yo: u32 = 0u; yo < p.Hout; yo = yo + 1u) {
      for (var xo: u32 = 0u; xo < p.Wout; xo = xo + 1u) {
        let u_flat = b * map_size + kout * p.Hout * p.Wout + yo * p.Wout + xo;
        acc = acc + rb * (rho(Up[u_flat]) - rho(Um[u_flat]));
      }
    }
  }
  gB[kout] = acc / p.two_beta;
}
`;

const WGSL_GRAD_DENSE_MULTI = `
struct DGP {
  B: u32, Ni: u32, No: u32, _p0: u32,
  two_beta: f32, _p1: f32, _p2: f32, _p3: f32,
};
@group(0) @binding(0) var<uniform> p : DGP;
@group(0) @binding(1) var<storage, read> Xp : array<f32>;
@group(0) @binding(2) var<storage, read> Xm : array<f32>;
@group(0) @binding(3) var<storage, read> Up : array<f32>;
@group(0) @binding(4) var<storage, read> Um : array<f32>;
@group(0) @binding(5) var<storage, read> R : array<f32>;
@group(0) @binding(6) var<storage, read_write> gW : array<f32>;
@group(0) @binding(7) var<storage, read_write> gB : array<f32>;

fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }

@compute @workgroup_size(8, 8)
fn grad_W_dense(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.y;
  let j = gid.x;
  if (i >= p.No || j >= p.Ni) { return; }
  var acc : f32 = 0.0;
  for (var b: u32 = 0u; b < p.B; b = b + 1u) {
    let rb = R[b];
    let ip = rho(Up[b * p.No + i]);
    let im = rho(Um[b * p.No + i]);
    let jp = rho(Xp[b * p.Ni + j]);
    let jm = rho(Xm[b * p.Ni + j]);
    acc = acc + rb * (ip * jp - im * jm);
  }
  gW[i * p.Ni + j] = acc / p.two_beta;
}

@compute @workgroup_size(64)
fn grad_B_dense(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.x;
  if (i >= p.No) { return; }
  var acc : f32 = 0.0;
  for (var b: u32 = 0u; b < p.B; b = b + 1u) {
    let rb = R[b];
    acc = acc + rb * (rho(Up[b * p.No + i]) - rho(Um[b * p.No + i]));
  }
  gB[i] = acc / p.two_beta;
}
`;

// Reward computation — identical to v03.
const WGSL_AUX_MULTI = ` struct AP { B: u32, O: u32, _p0: u32, _p1: u32, _p2: f32, _p3: f32, _p4: f32, _p5: f32, }; @group(0) @binding(0) var p : AP; @group(0) @binding(1) var UoF : array; @group(0) @binding(2) var Tgt : array; @group(0) @binding(3) var R : array; fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); } @compute @workgroup_size(64) fn compute_reward(@builtin(global_invocation_id) gid: vec3) { let b = gid.x; if (b >= p.B) { return; } var loss : f32 = 0.0; let off = b * p.O; for (var i: u32 = 0u; i < p.O; i = i + 1u) { let d = sg(UoF[off + i]) - Tgt[off + i]; loss = loss + d * d; } var r : f32 = loss / 0.4; if (r > 1.0) { r = 1.0; } R[b] = 0.1 + 0.9 * r; } `; export class GPUTrainerConvMulti { // convCfgs: array of {Cin, Cout, KH, KW, stride, pad, H, W} — first entry's Cin/H/W is input image, // subsequent entries' Cin/H/W must equal previous entry's Cout/Hout/Wout (validated in constructor). // denseSize OR denseSizes: // - denseSize (scalar): number of OUTPUT classes O (backward-compat with v03 single-dense) // - denseSizes (array): [hiddenSize1, hiddenSize2, ..., O] — Tier E heterogeneous trainer // conv stack feeds first dense; each dense feeds the next via top-down. // Last dense receives the ±β target nudge. // B: batch size constructor({dev, convCfgs, denseSize, denseSizes, B}){ this.dev = dev; if(!Array.isArray(convCfgs) || convCfgs.length < 1) throw new Error('convCfgs must be non-empty array'); // Resolve denseSizes (Tier E): if scalar denseSize passed, wrap as single-element array. 
if(denseSizes !== undefined){ if(!Array.isArray(denseSizes) || denseSizes.length < 1) throw new Error('denseSizes must be non-empty array'); this.denseSizes = denseSizes.slice(); } else if(denseSize !== undefined){ this.denseSizes = [denseSize]; } else { throw new Error('must pass denseSize (scalar) or denseSizes (array)'); } this.D = this.denseSizes.length; // number of dense layers this.O = this.denseSizes[this.D-1]; // output classes = last dense size this.cfgs = convCfgs.map(c => ({...c})); // shallow-copy entries this.N = this.cfgs.length; this.B = B; // Compute per-layer Hout/Wout and verify chain consistency let prevC = null, prevH = null, prevW = null; for(let l=0; l0: denseSizes[d-1] // No: denseSizes[d] this.bufWdense = []; this.bufBdense = []; this.bufGWdense = []; this.bufGBdense = []; this.rbGWdense = []; this.rbGBdense = []; this.bufUout = [[],[],[]]; // bufUout[phase][d] is dense layer d's state for(let d=0; d({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'read-only-storage'}}); const sRW = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'storage'}}); const uN = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'uniform'}}); const modConv = dev.createShaderModule({code: WGSL_CONV_RELAX_MULTI}); this.bglConv = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sRW(5), sR(6), sR(7), sR(8), sR(9)]}); this.plConv = dev.createPipelineLayout({bindGroupLayouts:[this.bglConv]}); this.pipeConv = dev.createComputePipeline({layout:this.plConv, compute:{module:modConv, entryPoint:'conv_pass'}}); this.pipeInitConv = dev.createComputePipeline({layout:this.plConv, compute:{module:modConv, entryPoint:'init_state'}}); const modDense = dev.createShaderModule({code: WGSL_DENSE_OUT_MULTI}); this.bglDense = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sR(6), sRW(7)]}); this.plDense = dev.createPipelineLayout({bindGroupLayouts:[this.bglDense]}); this.pipeDense = 
dev.createComputePipeline({layout:this.plDense, compute:{module:modDense, entryPoint:'dense_pass'}}); this.pipeInitDense = dev.createComputePipeline({layout:this.plDense, compute:{module:modDense, entryPoint:'init_state_out'}}); const modGC = dev.createShaderModule({code: WGSL_GRAD_CONV_MULTI}); this.bglGC = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sRW(6), sRW(7)]}); this.plGC = dev.createPipelineLayout({bindGroupLayouts:[this.bglGC]}); this.pipeGWconv = dev.createComputePipeline({layout:this.plGC, compute:{module:modGC, entryPoint:'grad_W_conv'}}); this.pipeGBconv = dev.createComputePipeline({layout:this.plGC, compute:{module:modGC, entryPoint:'grad_B_conv'}}); const modGD = dev.createShaderModule({code: WGSL_GRAD_DENSE_MULTI}); this.bglGD = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sRW(6), sRW(7)]}); this.plGD = dev.createPipelineLayout({bindGroupLayouts:[this.bglGD]}); this.pipeGWdense = dev.createComputePipeline({layout:this.plGD, compute:{module:modGD, entryPoint:'grad_W_dense'}}); this.pipeGBdense = dev.createComputePipeline({layout:this.plGD, compute:{module:modGD, entryPoint:'grad_B_dense'}}); // SI-5: skip W gradient uses the SAME grad_W_dense kernel — outer product of (last dense state) × (conv hidden). // Bind: Xp = conv_l hidden plus phase, Xm = conv_l hidden minus phase, Up = dense_last plus, Um = dense_last minus. // Output: gW with shape [denseSizes[D-1] × convFlat_l] — matches Wskip[l] layout. const modAux = dev.createShaderModule({code: WGSL_AUX_MULTI}); this.bglAux = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sRW(3)]}); this.plAux = dev.createPipelineLayout({bindGroupLayouts:[this.bglAux]}); this.pipeReward = dev.createComputePipeline({layout:this.plAux, compute:{module:modAux, entryPoint:'compute_reward'}}); // ---- bind groups ---- // Conv per (layer, phase). Inputs depend on layer index. 
// layer 0: Xin = bufXin // layer l>0: Xin = bufUconv[phase][l-1] (previous layer's U-state, post-σ via rho()) // Wnxt: bufWdense if last conv (top-down type=1), else bufWconv[l+1] (type=2) // Unxt: bufUout[phase] if last conv, else bufUconv[phase][l+1] this.bgConv = []; this.bgInitConv = []; for(let l=0; l0 reads previous dense state const Xin = (d === 0) ? this.bufUconv[p][N-1] : this.bufUout[p][d-1]; // top-down: hidden dense (not last) reads next dense's W and U; last has none (uses target nudge) const Wnxt = isLastD ? this.bufDummyR : this.bufWdense[d+1]; const Unxt = isLastD ? this.bufDummyR : this.bufUout[p][d+1]; this.bgDense[p].push(dev.createBindGroup({layout:this.bglDense, entries:[ {binding:0, resource:{buffer:this.bufP_dense[p][d]}}, {binding:1, resource:{buffer:Xin}}, {binding:2, resource:{buffer:this.bufWdense[d]}}, {binding:3, resource:{buffer:this.bufBdense[d]}}, {binding:4, resource:{buffer:Wnxt}}, {binding:5, resource:{buffer:Unxt}}, {binding:6, resource:{buffer:this.bufTgt}}, {binding:7, resource:{buffer:this.bufUout[p][d]}}, ]})); this.bgInitDense[p].push(dev.createBindGroup({layout:this.bglDense, entries:[ {binding:0, resource:{buffer:this.bufP_init_dense[p][d]}}, {binding:1, resource:{buffer:Xin}}, {binding:2, resource:{buffer:this.bufWdense[d]}}, {binding:3, resource:{buffer:this.bufBdense[d]}}, {binding:4, resource:{buffer:Wnxt}}, {binding:5, resource:{buffer:Unxt}}, {binding:6, resource:{buffer:this.bufTgt}}, {binding:7, resource:{buffer:this.bufUout[p][d]}}, ]})); } } // Grad bind groups per conv layer: // pre_p/pre_m = layer's INPUT (Xin if l=0, else previous layer's U-plus/minus phase) // post_p/post_m = THIS layer's U-plus/minus phase this.bgGC = []; for(let l=0; l