File size: 44,184 Bytes
d035fbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
// Multi-hidden-layer WebGPU EqProp trainer.
// sizes = [D, H1, H2, ..., Hk, O]
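// Modes (uniform `mode` flag written by the trainer):
//   0 = adaptive (σ everywhere, u-nudge at output)
//   1 = fhn      (clip everywhere, ρ-nudge at output)
//   2 = prism    (softplus soft-clip in relaxation, ρ-nudge at output)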
//
// Key design:
//   * ONE generic compute pipeline `pass_layer` updates any layer using uniforms (ni, no, nxt, ...).
//   * 3 phases (free / plus / minus): each has its own per-layer state buffer + per-layer bind group.
//   * Uniform buffers are pre-written once per pass (one per (phase, layer) pair), so the encoder
//     records dispatches that each pick up the right uniforms.
//   * Gradient kernels are also generic — per layer transition, compute outer product reduction over batch.

import { orth as orthCPU } from './eqprop_lib.js';

// ----- WGSL: relax (one kernel handles any layer) -----
const WGSL_RELAX = `
struct P {
  ni : u32, no : u32, nxt : u32, B : u32,
  dt : f32, beta : f32, gamma : f32, mode : f32,    // mode: 0=adaptive σ, 1=fhn(clip+cubic), 2=prism (soft-clip via softplus)
  has_topdown : u32, has_target : u32, noise_scale : f32, iter_seed : u32,
  // sEqProp: noise_scale > 0 injects per-iter per-neuron Gaussian-ish noise into drive.
  // Bio-faithful (real synapses are stochastic). At test, run M passes & average outputs.
  clamp_lo : f32, clamp_hi : f32, _p_t1 : f32, _p_t2 : f32,
  // Tier A — pre-σ drive clamp (algorithmic, uniform-driven). Bounds the pre-activation
  // value c before σ(c) to prevent saturation runaway. When clamp_hi <= clamp_lo the
  // kernel treats it as DISABLED (no-op). Default in caller = clamp_lo=clamp_hi=0 → disabled.
};
@group(0) @binding(0) var<uniform> p : P;
@group(0) @binding(1) var<storage, read>        Win  : array<f32>;
@group(0) @binding(2) var<storage, read>        W0   : array<f32>;  // [no x ni]
@group(0) @binding(3) var<storage, read>        b0   : array<f32>;  // [no]
@group(0) @binding(4) var<storage, read>        W1   : array<f32>;  // [nxt x no]   (top-down)
@group(0) @binding(5) var<storage, read_write>  Uh   : array<f32>;  // [B*no]
@group(0) @binding(6) var<storage, read>        Uo   : array<f32>;  // [B*nxt]
@group(0) @binding(7) var<storage, read>        Tgt  : array<f32>;  // [B*no]
// HPSN: heterogeneous time constants — per-neuron multiplier on drive integration.
// Tau[i] replaces the global p.dt. Constant Tau[i]=p.dt → behavior identical to scalar-dt EqProp.
// Sampled from Uniform[tau_min, tau_max] → diverse temporal scales like real cortical neurons.
@group(0) @binding(8) var<storage, read>        Tau  : array<f32>;  // [no]

const A1 : f32 = 0.07407407407407407;
const PRISM_K : f32 = 10.0;   // sharpness; higher k → harder clip

// PCG-style cheap hash → uniform [0,1). Per-thread, per-iter, per-neuron stochasticity.
fn pcg_hash(seed_in: u32) -> u32 {
  var state : u32 = seed_in * 747796405u + 2891336453u;
  let word : u32 = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
  return (word >> 22u) ^ word;
}
fn unif_noise(b: u32, i: u32, t: u32) -> f32 {
  // Triangular distribution (sum of 2 uniforms - 1) ≈ Gaussian-ish, mean=0, variance=1/6.
  let h1 = pcg_hash(b * 65537u + i * 257u + t * 31u);
  let h2 = pcg_hash(b * 31337u + i * 1031u + t * 17u + 12345u);
  let u1 = f32(h1) / 4294967296.0;
  let u2 = f32(h2) / 4294967296.0;
  return (u1 + u2) - 1.0;   // range [-1, 1], roughly triangular
}
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn fhn_rho(u: f32) -> f32 { return clamp(u, 0.0, 1.0); }
fn fhn_rho_p(u: f32) -> f32 { return select(0.0, 1.0, u > 0.0 && u < 1.0); }
fn fhn_f(u: f32) -> f32 { return A1 * u - u*u*u; }
// PRISM activation: ρ(u) = (softplus(k·u) - softplus(k·(u-1))) / k
// Smooth approximation of clip(u,0,1). Derivative: σ(k·u) - σ(k·(u-1)).
// "Prism" = splits drive into a smooth-yet-saturating activation with gradient flow on both sides.
fn softplus(x: f32) -> f32 { return select(x + log(1.0 + exp(-x)), log(1.0 + exp(x)), x <= 0.0); }
fn prism_rho(u: f32) -> f32 {
  return (softplus(PRISM_K * u) - softplus(PRISM_K * (u - 1.0))) / PRISM_K;
}
fn prism_rho_p(u: f32) -> f32 {
  let k = PRISM_K;
  return 1.0/(1.0+exp(-k*u)) - 1.0/(1.0+exp(-k*(u-1.0)));
}
fn rho(u: f32) -> f32 {
  if (p.mode > 1.5) { return prism_rho(u); }
  if (p.mode > 0.5) { return fhn_rho(u); }
  return sg(u);
}

@compute @workgroup_size(8, 8) fn pass_layer(@builtin(global_invocation_id) gid: vec3<u32>) {
  let b = gid.y; let i = gid.x;
  if (b >= p.B || i >= p.no) { return; }

  // bottom-up: c = b0[i] + sum_j W0[i,j] * rho(Win[b,j])
  var c : f32 = b0[i];
  let row0 = i * p.ni;
  let xoff = b * p.ni;
  for (var j: u32 = 0u; j < p.ni; j = j + 1u) {
    c = c + W0[row0 + j] * rho(Win[xoff + j]);
  }
  // top-down: gamma * sum_k W1[k,i] * rho(Uo[b,k])  (if next layer exists)
  if (p.has_topdown != 0u) {
    var td : f32 = 0.0;
    let uo_off = b * p.nxt;
    for (var k: u32 = 0u; k < p.nxt; k = k + 1u) {
      td = td + W1[k * p.no + i] * rho(Uo[uo_off + k]);
    }
    if (p.mode > 0.5) { c = c + td; } else { c = c + p.gamma * td; }
  }

  // Tier A — pre-σ drive clamp (algorithmic). Active iff clamp_hi > clamp_lo.
  if (p.clamp_hi > p.clamp_lo) {
    c = clamp(c, p.clamp_lo, p.clamp_hi);
  }

  let idx = b * p.no + i;
  let u_old = Uh[idx];
  // sEqProp noise: per-(b, i, iter_seed) triangular noise added to drive. Zero by default.
  let noise = select(0.0, p.noise_scale * unif_noise(b, i, p.iter_seed), p.noise_scale > 0.0);
  var u_new : f32;
  if (p.mode > 1.5) {
    // PRISM: u̇ = ρ'(u)·c + (linear pull) ; ρ-nudge for output. Smooth saturating dynamics.
    var drive : f32 = prism_rho_p(u_old) * c - 0.1 * (u_old - 0.5) + noise;
    if (p.has_target != 0u && p.beta != 0.0) {
      drive = drive + p.beta * (Tgt[idx] - prism_rho(u_old));
    }
    u_new = u_old + Tau[i] * drive;
    u_new = clamp(u_new, -0.3, 1.3);
  } else if (p.mode > 0.5) {
    // FHN
    var drive : f32 = fhn_rho_p(u_old) * c + fhn_f(u_old) + noise;
    if (p.has_target != 0u && p.beta != 0.0) {
      drive = drive + p.beta * (Tgt[idx] - fhn_rho(u_old));
    }
    u_new = u_old + Tau[i] * drive;
    u_new = clamp(u_new, -0.2, 1.2);
  } else {
    // Adaptive
    var drive : f32 = -u_old + sg(c) + noise;
    if (p.has_target != 0u && p.beta != 0.0) {
      drive = drive + p.beta * (Tgt[idx] - u_old);
    }
    u_new = u_old + Tau[i] * drive;
  }
  Uh[idx] = u_new;
}

// 2D dispatch to handle big buffers (B*no can exceed the 65535 per-dim workgroup limit).
@compute @workgroup_size(64) fn init_state(@builtin(global_invocation_id) gid: vec3<u32>) {
  let stride = 65535u * 64u;   // workgroups_per_X * threads_per_workgroup
  let g = gid.y * stride + gid.x;
  let n = p.B * p.no;
  if (g < n) { Uh[g] = 0.1; }
}
`;

// ----- WGSL: gradient (one kernel handles any layer transition) -----
// WGSL source for the gradient shader module (runtime string — compiled as-is).
//
// Uniform struct GP is 8 x 32-bit slots (32 bytes); _writeGradParams() fills it in the same order:
//   [0..3] ni, no, _pad, B (u32)     [4..7] c, two_beta, mode_pre, mode_post (f32)
// `c` is declared but not read by either entry point in this module.
//
// Activation selection in rho_mode(u, m) is by threshold:
//   m <= 0.5 → σ, 0.5 < m <= 1.5 → hard clip, 1.5 < m <= 2.5 → prism soft-clip, m > 2.5 → identity.
// NOTE(review): the inline struct comment inside the string says "2=identity", which does not
// match these thresholds (identity needs m > 2.5) — confirm what callers pass for the input layer.
//
// Entry points:
//   grad_W — workgroup 8x8; gid.y = post neuron i (< no), gid.x = pre neuron j (< ni).
//            gW[i*ni + j] = (Σ_b R[b] * (ρ(post+)ρ(pre+) − ρ(post−)ρ(pre−))) / two_beta.
//   grad_B — workgroup 64; gid.x = post neuron i. gB[i] = (Σ_b R[b] * (ρ(post+) − ρ(post−))) / two_beta.
const WGSL_GRAD = `
struct GP {
  ni : u32, no : u32, _pad : u32, B : u32,
  c : f32, two_beta : f32, mode_pre : f32, mode_post : f32,   // mode_pre/post: 0=σ, 1=clip, 2=identity
};
@group(0) @binding(0) var<uniform> p : GP;
@group(0) @binding(1) var<storage, read>        UpreP : array<f32>;  // [B*ni] - "input" layer state, plus phase
@group(0) @binding(2) var<storage, read>        UpreM : array<f32>;  // [B*ni] - minus
@group(0) @binding(3) var<storage, read>        UpostP: array<f32>;  // [B*no]
@group(0) @binding(4) var<storage, read>        UpostM: array<f32>;  // [B*no]
@group(0) @binding(5) var<storage, read>        R     : array<f32>;  // [B]
@group(0) @binding(6) var<storage, read_write>  gW    : array<f32>;  // [no*ni]
@group(0) @binding(7) var<storage, read_write>  gB    : array<f32>;  // [no]

fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
const PRISM_K2 : f32 = 10.0;
fn softplus2(x: f32) -> f32 { return select(x + log(1.0 + exp(-x)), log(1.0 + exp(x)), x <= 0.0); }
fn prism_rho_g(u: f32) -> f32 { return (softplus2(PRISM_K2*u) - softplus2(PRISM_K2*(u-1.0))) / PRISM_K2; }
fn rho_mode(u: f32, m: f32) -> f32 {
  if (m > 2.5) { return u; }                       // identity (linear)
  if (m > 1.5) { return prism_rho_g(u); }          // prism soft-clip
  if (m > 0.5) { return clamp(u, 0.0, 1.0); }      // hard-clip (FHN)
  return sg(u);                                    // σ (adaptive)
}

@compute @workgroup_size(8, 8) fn grad_W(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.y; let j = gid.x;
  if (i >= p.no || j >= p.ni) { return; }
  var acc : f32 = 0.0;
  for (var b: u32 = 0u; b < p.B; b = b + 1u) {
    let rh = R[b];
    let ip = rho_mode(UpostP[b * p.no + i], p.mode_post);
    let im = rho_mode(UpostM[b * p.no + i], p.mode_post);
    let jp = rho_mode(UpreP[b * p.ni + j], p.mode_pre);
    let jm = rho_mode(UpreM[b * p.ni + j], p.mode_pre);
    acc = acc + rh * (ip * jp - im * jm);
  }
  gW[i * p.ni + j] = acc / p.two_beta;
}
@compute @workgroup_size(64) fn grad_B(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.x;
  if (i >= p.no) { return; }
  var acc : f32 = 0.0;
  for (var b: u32 = 0u; b < p.B; b = b + 1u) {
    let rh = R[b];
    let ip = rho_mode(UpostP[b * p.no + i], p.mode_post);
    let im = rho_mode(UpostM[b * p.no + i], p.mode_post);
    acc = acc + rh * (ip - im);
  }
  gB[i] = acc / p.two_beta;
}
`;

// ----- WGSL: reward + adaptation (depends on output layer state) -----
// WGSL source for the auxiliary shader module (runtime string — compiled as-is).
//
// Uniform struct AP is 8 x 32-bit slots (32 bytes); _writeAuxParams() fills it in the same order:
//   [0..3] B, O, H_max, n_hidden (u32)     [4..7] c, mode, two padding floats (f32)
// H_max and n_hidden are declared but not read by either entry point in this module.
//
// Both entry points share one bind-group layout, so each bind group supplies dummy buffers for
// the slots its entry point does not use (bindings 4-6 for reward, 1-3 for adaptation).
//
// Entry points:
//   compute_reward — workgroup 64; gid.x = batch index b. Squared error between rho_out(UoF) and
//                    Tgt over the O outputs, divided by 0.4, capped at 1, then mapped to
//                    R[b] = 0.1 + 0.9 * r. rho_out uses the hard clip for any mode > 0.5.
//   adapt_layer    — workgroup 64; flattens (gid.y, gid.x) with stride 65535*64 and blends
//                    Up[g], Um[g] ← (1-c)*old + c*Uf[g] for g < arrayLength(Uf).
const WGSL_AUX = `
struct AP {
  B : u32, O : u32, H_max : u32, n_hidden : u32,
  c : f32, mode : f32, _p0 : f32, _p1 : f32,
};
@group(0) @binding(0) var<uniform> p : AP;
@group(0) @binding(1) var<storage, read>        UoF : array<f32>;
@group(0) @binding(2) var<storage, read>        Tgt : array<f32>;
@group(0) @binding(3) var<storage, read_write>  R   : array<f32>;
// adaptation buffers (variable size; we pass single layer at a time via separate bind groups)
@group(0) @binding(4) var<storage, read>        Uf  : array<f32>;
@group(0) @binding(5) var<storage, read_write>  Up  : array<f32>;
@group(0) @binding(6) var<storage, read_write>  Um  : array<f32>;

fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho_out(u: f32) -> f32 {
  if (p.mode > 0.5) { return clamp(u, 0.0, 1.0); }
  return sg(u);
}

@compute @workgroup_size(64) fn compute_reward(@builtin(global_invocation_id) gid: vec3<u32>) {
  let b = gid.x;
  if (b >= p.B) { return; }
  var loss : f32 = 0.0;
  let off = b * p.O;
  for (var i: u32 = 0u; i < p.O; i = i + 1u) {
    let d = rho_out(UoF[off + i]) - Tgt[off + i];
    loss = loss + d * d;
  }
  let escale : f32 = 0.4;
  let rmin   : f32 = 0.1;
  var r : f32 = loss / escale;
  if (r > 1.0) { r = 1.0; }
  R[b] = rmin + (1.0 - rmin) * r;
}

// Adjusted Adaptation per layer: Up,Um ← (1-c)*Up + c*Uf. 2D dispatch safe for large buffers.
@compute @workgroup_size(64) fn adapt_layer(@builtin(global_invocation_id) gid: vec3<u32>) {
  let stride = 65535u * 64u;
  let g = gid.y * stride + gid.x;
  if (g >= arrayLength(&Uf)) { return; }
  let f = Uf[g];
  Up[g] = (1.0 - p.c) * Up[g] + p.c * f;
  Um[g] = (1.0 - p.c) * Um[g] + p.c * f;
}
`;

// ----- JS: trainer class -----
/**
 * Acquire a WebGPU adapter and device for the trainer.
 *
 * The device is requested with a handful of limits raised to whatever the adapter reports
 * (storage buffers per stage, buffer sizes, workgroup limits, bind groups); limits the adapter
 * does not expose as numbers are simply not requested.
 *
 * @param {{powerPreference?: string}} [options] - forwarded to `navigator.gpu.requestAdapter`.
 * @returns {Promise<{adapter: object, dev: object, info: object}>} the adapter, the device, and
 *   `adapter.info` (or `{}` when the adapter does not expose it).
 * @throws {Error} 'no webgpu' when WebGPU is unavailable, 'no adapter' when no adapter is returned.
 */
export async function makeGPUDeep({powerPreference='high-performance'}={}){
  // Check the global itself first: in runtimes without a `navigator` global a bare
  // `navigator.gpu` access raises ReferenceError instead of the intended 'no webgpu' error.
  if(typeof navigator === 'undefined' || !navigator.gpu) throw new Error('no webgpu');
  const adapter = await navigator.gpu.requestAdapter({powerPreference});
  if(!adapter) throw new Error('no adapter');
  // Ask for the adapter's own maximum for each of these limits (skipping any it does not report).
  const want = {};
  const tryKeys = ['maxStorageBuffersPerShaderStage','maxBufferSize','maxStorageBufferBindingSize',
    'maxComputeInvocationsPerWorkgroup','maxComputeWorkgroupSizeX','maxComputeWorkgroupStorageSize','maxBindGroups'];
  for(const k of tryKeys){ const v=adapter.limits[k]; if(typeof v==='number') want[k]=v; }
  const dev = await adapter.requestDevice({requiredLimits: want});
  return {adapter, dev, info: adapter.info||{}};
}

// Phase indices (free / plus / minus) used as the first index into per-phase arrays such as bufU.
const PHASE_F = 0;
const PHASE_P = 1;
const PHASE_M = 2;

export class GPUTrainerDeep {
  // sizes: [D, H1, H2, ..., Hk, O] — len L+1; L = number of weight matrices = sizes.length-1
  // mode: 'adaptive' | 'fhn' | 'prism'
  // driveClampLo, driveClampHi: Tier A — pre-σ drive clamp; ACTIVE iff hi > lo. Default 0,0 = disabled.
  constructor({dev, sizes, B, mode='adaptive', gamma=0.6, hpsnTauMin=0, hpsnTauMax=0, hpsnSeed=42, driveClampLo=0, driveClampHi=0}={}){
    this.dev = dev; this.sizes = sizes;
    this.L = sizes.length - 1;   // number of weight matrices (transitions)
    this.B = B; this.O = sizes[sizes.length-1];
    this.mode = mode;
    this.modeFlag = (mode==='prism') ? 2.0 : (mode==='fhn' ? 1.0 : 0.0);
    this.gamma = gamma;
    this.hpsnTauMin = hpsnTauMin;
    this.hpsnTauMax = hpsnTauMax;
    this.hpsnSeed = hpsnSeed;
    this.useHPSN = (hpsnTauMax > hpsnTauMin && hpsnTauMin > 0);
    this.driveClampLo = driveClampLo;
    this.driveClampHi = driveClampHi;
    this._build();
    // Initialize Tau buffers — either constant=0.7 (backward compat) or per-neuron Uniform[hpsnTauMin, hpsnTauMax].
    if(this.useHPSN){
      this.setAllTau(0.7, hpsnTauMin, hpsnTauMax, hpsnSeed);
    } else {
      this.setAllTau(0.7);
    }
  }
  _F32buf(n, usage){
    if(!Number.isFinite(n) || n <= 0){
      console.error('BAD _F32buf size', {n, sizes:JSON.stringify(this.sizes), sizesArr:this.sizes, B:this.B, L:this.L, S0:this.sizes&&this.sizes[0], S0type:typeof (this.sizes&&this.sizes[0])});
      throw new Error('_F32buf called with non-finite n=' + n + ' sizes=' + JSON.stringify(this.sizes));
    }
    return this.dev.createBuffer({size:Math.max(4,n*4), usage});
  }
  _build(){
    // Creates every GPU resource the trainer uses, in this order:
    //   1. data buffers (input, target, weights, biases, per-layer Tau, per-phase layer states,
    //      reward, dummies, gradient + readback buffers),
    //   2. the three shader modules and their compute pipelines,
    //   3. uniform buffers (relax per (phase, layer), init/adapt per layer, grad per transition, reward),
    //   4. bind groups wiring buffers to the pipelines.
    // Reads this.dev, this.sizes, this.B, this.L; stores everything it creates on `this`.
    const dev = this.dev, S = this.sizes, B = this.B, L = this.L;
    // NOTE: RW and R below are the same flag combination (STORAGE|COPY_SRC|COPY_DST); the two names
    // only document intent (shader-writable vs. shader-read-only binding).
    const RW  = GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_SRC|GPUBufferUsage.COPY_DST;
    const R   = GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST|GPUBufferUsage.COPY_SRC;
    const UNI = GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST;
    const RDS = GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ;

    // input + target (shared across phases)
    this.bufWin = this._F32buf(B * S[0], R);
    this.bufTgt = this._F32buf(B * S[L], R);
    // weights & biases (one per transition)
    this.bufW = []; this.bufB = [];
    for(let l=0; l<L; l++){
      this.bufW.push(this._F32buf(S[l+1]*S[l], R));
      this.bufB.push(this._F32buf(S[l+1], R));
    }
    // HPSN: per-layer Tau buffer of size [no]. Contents are written by setTau()/setAllTau()
    // (the constructor calls setAllTau right after _build).
    this.bufTau = [];
    for(let l=0; l<L; l++){ this.bufTau.push(this._F32buf(S[l+1], R)); }
    // state buffers: for each of 3 phases, L state buffers (one per non-input layer)
    this.bufU = [[],[],[]];   // bufU[phase][l] is layer l+1's state (l=0..L-1, sizes S[1..L])
    for(let phase=0; phase<3; phase++){
      for(let l=1; l<=L; l++){
        this.bufU[phase].push(this._F32buf(B * S[l], RW));
      }
    }
    // reward + dummies (need separate buffers for read-only vs writable slots to avoid aliasing).
    this.bufR        = this._F32buf(B, RW);
    this.bufDummyR   = this._F32buf(4, R);    // read-only dummy
    this.bufDummyRW1 = this._F32buf(4, RW);   // writable dummy slot 1
    this.bufDummyRW2 = this._F32buf(4, RW);   // writable dummy slot 2 (different buffer!)
    this.bufDummyRW3 = this._F32buf(4, RW);

    // gradient buffer (packed: all gW and gB together)
    // gOff[l] records, in float elements, where layer l's gW / gB live inside the packed layout.
    const gSizes = []; let total=0;
    for(let l=0; l<L; l++){ gSizes.push({offW:total, sizW:S[l+1]*S[l], offB:total+S[l+1]*S[l], sizB:S[l+1], total:S[l+1]*S[l]+S[l+1]}); total += gSizes[l].total; }
    this.gOff = gSizes; this.gTotal = total;
    this.bufG = this._F32buf(total, RW);
    this.rbG  = dev.createBuffer({size: total*4, usage: RDS});
    // readback for output free-phase Uo (for accuracy/loss)
    this.rbUoF = dev.createBuffer({size: B * S[L] * 4, usage: RDS});

    // ---- pipelines ----
    // Relax pipeline (generic)
    const modR = dev.createShaderModule({code: WGSL_RELAX});
    // Bind-group-layout entry helpers: read-only storage, read-write storage, uniform.
    const sR = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'read-only-storage'}});
    const sRW= (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'storage'}});
    const uN = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'uniform'}});
    this.bglR = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sRW(5), sR(6), sR(7), sR(8)]});
    this.plR  = dev.createPipelineLayout({bindGroupLayouts:[this.bglR]});
    this.pipeRelax = dev.createComputePipeline({layout:this.plR, compute:{module:modR, entryPoint:'pass_layer'}});
    this.pipeInit  = dev.createComputePipeline({layout:this.plR, compute:{module:modR, entryPoint:'init_state'}});

    // Grad pipeline (generic)
    const modG = dev.createShaderModule({code: WGSL_GRAD});
    this.bglG = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sRW(6), sRW(7)]});
    this.plG  = dev.createPipelineLayout({bindGroupLayouts:[this.bglG]});
    this.pipeGW = dev.createComputePipeline({layout:this.plG, compute:{module:modG, entryPoint:'grad_W'}});
    this.pipeGB = dev.createComputePipeline({layout:this.plG, compute:{module:modG, entryPoint:'grad_B'}});

    // Aux pipeline (reward + adaptation)
    const modA = dev.createShaderModule({code: WGSL_AUX});
    this.bglA = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sRW(3), sR(4), sRW(5), sRW(6)]});
    this.plA  = dev.createPipelineLayout({bindGroupLayouts:[this.bglA]});
    this.pipeReward = dev.createComputePipeline({layout:this.plA, compute:{module:modA, entryPoint:'compute_reward'}});
    this.pipeAdapt  = dev.createComputePipeline({layout:this.plA, compute:{module:modA, entryPoint:'adapt_layer'}});

    // ---- uniform buffers (one per (phase, layer)) ----
    // Each entry: 64 bytes (16 u32/f32 slots), matching struct P in WGSL_RELAX.
    this.bufP_relax = [];
    for(let phase=0; phase<3; phase++){
      this.bufP_relax.push([]);
      for(let l=1; l<=L; l++){
        this.bufP_relax[phase].push(dev.createBuffer({size: 64, usage: UNI}));
      }
    }
    // init uniform (one per layer, beta=0)
    this.bufP_init = [];
    for(let l=1; l<=L; l++) this.bufP_init.push(dev.createBuffer({size: 64, usage: UNI}));

    // Grad uniform (one per layer transition) — 32 bytes, matching struct GP in WGSL_GRAD.
    this.bufP_grad = [];
    for(let l=0; l<L; l++) this.bufP_grad.push(dev.createBuffer({size: 32, usage: UNI}));

    // Aux uniforms: reward + per-layer adaptation — 32 bytes each, matching struct AP in WGSL_AUX.
    this.bufP_rew = dev.createBuffer({size: 32, usage: UNI});
    this.bufP_adapt = []; for(let l=1; l<=L; l++) this.bufP_adapt.push(dev.createBuffer({size: 32, usage: UNI}));

    // ---- bind groups ----
    // Relax: per (phase, layer)
    this.bgR = [[],[],[]];
    for(let phase=0; phase<3; phase++){
      for(let l=1; l<=L; l++){
        // For layer l (1-indexed): Win = state[phase][l-2] if l>1 else bufWin
        //                         W0 = bufW[l-1], b0 = bufB[l-1]
        //                         W1 = bufW[l] (top-down weights to layer l+1), used if l<L
        //                         Uh = state[phase][l-1]
        //                         Uo = state[phase][l] (next layer), used if l<L
        //                         Tgt = bufTgt (used if l==L)
        // For the last layer (l == L) the W1 and Uo slots are filled with the read-only dummy.
        const Win = (l===1) ? this.bufWin : this.bufU[phase][l-2];
        const W0  = this.bufW[l-1], b0 = this.bufB[l-1];
        const W1  = (l < L) ? this.bufW[l] : this.bufDummyR;
        const Uh  = this.bufU[phase][l-1];
        const Uo  = (l < L) ? this.bufU[phase][l] : this.bufDummyR;
        const Tgt = this.bufTgt;
        this.bgR[phase].push(dev.createBindGroup({layout: this.bglR, entries:[
          {binding:0, resource:{buffer:this.bufP_relax[phase][l-1]}},
          {binding:1, resource:{buffer:Win}},
          {binding:2, resource:{buffer:W0}},
          {binding:3, resource:{buffer:b0}},
          {binding:4, resource:{buffer:W1}},
          {binding:5, resource:{buffer:Uh}},
          {binding:6, resource:{buffer:Uo}},
          {binding:7, resource:{buffer:Tgt}},
          {binding:8, resource:{buffer:this.bufTau[l-1]}},
        ]}));
      }
    }
    // Init: per (phase, layer) — uses bufP_init (beta=0)
    // Same buffer wiring as the relax bind groups above; only the uniform buffer at binding 0 differs.
    this.bgInit = [[],[],[]];
    for(let phase=0; phase<3; phase++){
      for(let l=1; l<=L; l++){
        const Win = (l===1) ? this.bufWin : this.bufU[phase][l-2];
        const W0  = this.bufW[l-1], b0 = this.bufB[l-1];
        const W1  = (l < L) ? this.bufW[l] : this.bufDummyR;
        const Uh  = this.bufU[phase][l-1];
        const Uo  = (l < L) ? this.bufU[phase][l] : this.bufDummyR;
        const Tgt = this.bufTgt;
        this.bgInit[phase].push(dev.createBindGroup({layout: this.bglR, entries:[
          {binding:0, resource:{buffer:this.bufP_init[l-1]}},
          {binding:1, resource:{buffer:Win}},
          {binding:2, resource:{buffer:W0}},
          {binding:3, resource:{buffer:b0}},
          {binding:4, resource:{buffer:W1}},
          {binding:5, resource:{buffer:Uh}},
          {binding:6, resource:{buffer:Uo}},
          {binding:7, resource:{buffer:Tgt}},
          {binding:8, resource:{buffer:this.bufTau[l-1]}},
        ]}));
      }
    }
    // Grad: per layer transition. Note: gW and gB need offsets into bufG.
    // Approach: instead of bindings to subranges, we make a SEPARATE buffer per layer for gradients (simpler).
    // Pack later by reading back.
    this.bufGW = []; this.bufGB = []; this.rbGW = []; this.rbGB = [];
    for(let l=0; l<L; l++){
      const gw = this._F32buf(S[l+1]*S[l], RW);
      const gb = this._F32buf(S[l+1], RW);
      this.bufGW.push(gw); this.bufGB.push(gb);
      this.rbGW.push(dev.createBuffer({size: S[l+1]*S[l]*4, usage: RDS}));
      this.rbGB.push(dev.createBuffer({size: S[l+1]*4, usage: RDS}));
    }
    // Per layer-transition bind groups for grad:
    // Pre layer = state[l-1] (sizes[l]), Post layer = state[l] (sizes[l+1])
    // For l=0 (input transition): Pre = bufWin, Post = state[0] (sizes[1])
    // For l>0: Pre = state[l-1] (sizes[l]), Post = state[l] (sizes[l+1])
    this.bgG = [];
    for(let l=0; l<L; l++){
      const UpreP = (l===0) ? this.bufWin : this.bufU[PHASE_P][l-1];
      const UpreM = (l===0) ? this.bufWin : this.bufU[PHASE_M][l-1];
      const UpostP = this.bufU[PHASE_P][l];
      const UpostM = this.bufU[PHASE_M][l];
      this.bgG.push(dev.createBindGroup({layout: this.bglG, entries:[
        {binding:0, resource:{buffer:this.bufP_grad[l]}},
        {binding:1, resource:{buffer:UpreP}},
        {binding:2, resource:{buffer:UpreM}},
        {binding:3, resource:{buffer:UpostP}},
        {binding:4, resource:{buffer:UpostM}},
        {binding:5, resource:{buffer:this.bufR}},
        {binding:6, resource:{buffer:this.bufGW[l]}},
        {binding:7, resource:{buffer:this.bufGB[l]}},
      ]}));
    }
    // Reward bind group (uses output layer free-phase state)
    // Bindings 4-6 (the adaptation slots) get dummies: one read-only and two distinct writable buffers.
    const Uo_free = this.bufU[PHASE_F][L-1];
    this.bgRew = dev.createBindGroup({layout: this.bglA, entries:[
      {binding:0, resource:{buffer:this.bufP_rew}},
      {binding:1, resource:{buffer:Uo_free}},
      {binding:2, resource:{buffer:this.bufTgt}},
      {binding:3, resource:{buffer:this.bufR}},
      {binding:4, resource:{buffer:this.bufDummyR}},
      {binding:5, resource:{buffer:this.bufDummyRW1}},
      {binding:6, resource:{buffer:this.bufDummyRW2}},
    ]});
    // Adaptation bind groups (per layer): adapt Up,Um toward Uf
    // Bindings 1-3 (the reward slots) get dummies; 4/5/6 are the free/plus/minus states of the layer.
    this.bgAdapt = [];
    for(let l=1; l<=L; l++){
      this.bgAdapt.push(dev.createBindGroup({layout: this.bglA, entries:[
        {binding:0, resource:{buffer:this.bufP_adapt[l-1]}},
        {binding:1, resource:{buffer:this.bufDummyR}},
        {binding:2, resource:{buffer:this.bufDummyR}},
        {binding:3, resource:{buffer:this.bufDummyRW3}},
        {binding:4, resource:{buffer:this.bufU[PHASE_F][l-1]}},
        {binding:5, resource:{buffer:this.bufU[PHASE_P][l-1]}},
        {binding:6, resource:{buffer:this.bufU[PHASE_M][l-1]}},
      ]}));
    }
  }

  _writeRelaxParams(buf, {ni, no, nxt, B, dt, beta, gamma, mode, has_topdown, has_target, noise_scale=0, iter_seed=0, clamp_lo=0, clamp_hi=0}){
    const buf32 = new ArrayBuffer(64);
    const u32 = new Uint32Array(buf32); const f32 = new Float32Array(buf32);
    u32[0]=ni; u32[1]=no; u32[2]=nxt; u32[3]=B;
    f32[4]=dt; f32[5]=beta; f32[6]=gamma; f32[7]=mode;
    u32[8]=has_topdown; u32[9]=has_target;
    f32[10]=noise_scale;
    u32[11]=iter_seed;
    f32[12]=clamp_lo; f32[13]=clamp_hi; f32[14]=0; f32[15]=0;
    this.dev.queue.writeBuffer(buf, 0, buf32);
  }
  _writeGradParams(buf, {ni, no, B, two_beta, mode_pre, mode_post}){
    const buf32 = new ArrayBuffer(32);
    const u32 = new Uint32Array(buf32); const f32 = new Float32Array(buf32);
    u32[0]=ni; u32[1]=no; u32[2]=0; u32[3]=B;
    f32[4]=0; f32[5]=two_beta; f32[6]=mode_pre; f32[7]=mode_post;
    this.dev.queue.writeBuffer(buf, 0, buf32);
  }
  _writeAuxParams(buf, {B, O, c, mode}){
    const buf32 = new ArrayBuffer(32);
    const u32 = new Uint32Array(buf32); const f32 = new Float32Array(buf32);
    u32[0]=B; u32[1]=O; u32[2]=0; u32[3]=0;
    f32[4]=c; f32[5]=mode; f32[6]=0; f32[7]=0;
    this.dev.queue.writeBuffer(buf, 0, buf32);
  }

  uploadWeights(W, b){
    const q = this.dev.queue;
    for(let l=0; l<this.L; l++){
      q.writeBuffer(this.bufW[l], 0, W[l].buffer, W[l].byteOffset, W[l].byteLength);
      q.writeBuffer(this.bufB[l], 0, b[l].buffer, b[l].byteOffset, b[l].byteLength);
    }
  }
  // HPSN: set per-neuron time constants. If tauMax > tauMin > 0, samples Uniform[tauMin, tauMax].
  // Otherwise fills with constant scalarTau (backward-compat with old fixed-dt EqProp).
  setTau(layerIdx, scalarTau, tauMin=0, tauMax=0, seed=42){
    const no = this.sizes[layerIdx+1];
    const arr = new Float32Array(no);
    if(tauMax > tauMin && tauMin > 0){
      // Deterministic LCG for reproducible per-neuron tau distribution.
      let s = (seed>>>0) || 1;
      const rng = ()=>{ s = (Math.imul(s, 1664525) + 1013904223) >>> 0; return s/4294967296; };
      for(let i=0;i<no;i++) arr[i] = tauMin + rng()*(tauMax - tauMin);
    } else {
      arr.fill(scalarTau);
    }
    this.dev.queue.writeBuffer(this.bufTau[layerIdx], 0, arr.buffer, arr.byteOffset, arr.byteLength);
    return arr;  // return so caller can inspect distribution
  }
  // Convenience: set ALL layers to the same (scalar or distribution) tau spec.
  setAllTau(scalarTau, tauMin=0, tauMax=0, seed=42){
    for(let l=0; l<this.L; l++) this.setTau(l, scalarTau, tauMin, tauMax, seed + l*1000);
  }
  uploadInputs(X, T){
    const q = this.dev.queue;
    q.writeBuffer(this.bufWin, 0, X.buffer, X.byteOffset, X.byteLength);
    q.writeBuffer(this.bufTgt, 0, T.buffer, T.byteOffset, T.byteLength);
  }

  _initAllPhases(enc){
    const L = this.L;
    const MAX_WG_X = 65535;
    for(let phase=0; phase<3; phase++){
      for(let l=1; l<=L; l++){
        const n = this.B * this.sizes[l];
        const wgTotal = Math.ceil(n/64);
        // 2D dispatch when wgTotal exceeds per-dim limit
        const wgX = Math.min(wgTotal, MAX_WG_X);
        const wgY = Math.ceil(wgTotal / MAX_WG_X);
        const pass = enc.beginComputePass();
        pass.setPipeline(this.pipeInit);
        pass.setBindGroup(0, this.bgInit[phase][l-1]);
        pass.dispatchWorkgroups(wgX, wgY);
        pass.end();
      }
    }
  }
  // CRITICAL: each layer update must be in its OWN compute pass so the GPU sees
  // a barrier between writes to layer l's state and reads of that state by layer l+1.
  // WebGPU has no implicit synchronization between dispatches within a single pass.
  _runPhaseRelax(enc, phase, iters){
    const L = this.L, B = this.B;
    for(let t=0; t<iters; t++){
      for(let l=1; l<=L; l++){
        const pass = enc.beginComputePass();
        pass.setPipeline(this.pipeRelax);
        pass.setBindGroup(0, this.bgR[phase][l-1]);
        const no = this.sizes[l];
        pass.dispatchWorkgroups(Math.ceil(no/8), Math.ceil(B/8));
        pass.end();
      }
    }
  }
  // Write all uniform buffers for the pass.
  // noiseScale > 0 enables sEqProp; seedBase is added to iteration counter for per-call variation.
  _writeAllUniformsForPass(dt, beta, noiseScale=0, seedBase=0){
    const S=this.sizes, L=this.L, B=this.B, gam=this.gamma, mf=this.modeFlag;
    const phaseBetas = [0, +beta, -beta];   // free, plus, minus
    const ns = (typeof this.noiseScale === 'number') ? this.noiseScale : noiseScale;
    const sb = (typeof this.noiseSeedBase === 'number') ? this.noiseSeedBase : seedBase;
    const cLo = this.driveClampLo || 0;
    const cHi = this.driveClampHi || 0;
    // Relax uniforms (per phase, per layer). iter_seed is incremented per call below.
    for(let phase=0; phase<3; phase++){
      for(let l=1; l<=L; l++){
        const isOut = (l === L);
        const isHid = !isOut;
        const ni = S[l-1], no = S[l], nxt = isHid ? S[l+1] : 0;
        const phaseBeta = (isOut) ? phaseBetas[phase] : 0;
        this._writeRelaxParams(this.bufP_relax[phase][l-1], {
          ni, no, nxt, B, dt, beta: phaseBeta, gamma: gam, mode: mf,
          has_topdown: isHid ? 1 : 0, has_target: isOut ? 1 : 0,
          noise_scale: ns,
          iter_seed: (sb + phase * 7919 + (l-1) * 1009) >>> 0,
          clamp_lo: cLo, clamp_hi: cHi,
        });
      }
    }
    for(let l=1; l<=L; l++){
      this._writeRelaxParams(this.bufP_init[l-1], {
        ni: S[l-1], no: S[l], nxt: 0, B, dt, beta: 0, gamma: gam, mode: mf, has_topdown: 0, has_target: 0,
        noise_scale: 0, iter_seed: 0,   // init kernel doesn't use noise
        clamp_lo: 0, clamp_hi: 0,        // init kernel doesn't run drive — clamp irrelevant
      });
    }
  }
  // Tier A — runtime setter for drive clamp. Pass (0,0) to disable.
  setDriveClamp(lo, hi){
    this.driveClampLo = lo;
    this.driveClampHi = hi;
  }
  // sEqProp: set per-pass noise scale and seed base. Call before runFreeAndReadOutputs / runOnePass.
  setSEqPropNoise(noiseScale, seedBase){
    this.noiseScale = noiseScale;
    this.noiseSeedBase = (seedBase >>> 0) || 0;
  }
  _runReward(enc){
    this._writeAuxParams(this.bufP_rew, {B: this.B, O: this.O, c: 0, mode: this.modeFlag});
    const pass = enc.beginComputePass();
    pass.setPipeline(this.pipeReward);
    pass.setBindGroup(0, this.bgRew);
    pass.dispatchWorkgroups(Math.ceil(this.B/64));
    pass.end();
  }
  _runAdaptation(enc, adpC, adpSteps){
    if(adpSteps <= 0 || this.mode === 'fhn') return;  // skip adaptation in FHN mode
    const L = this.L;
    for(let l=1; l<=L; l++){
      this._writeAuxParams(this.bufP_adapt[l-1], {B: this.B, O: this.O, c: adpC, mode: this.modeFlag});
    }
    const MAX_WG_X = 65535;
    for(let a=0; a<adpSteps; a++){
      for(let l=1; l<=L; l++){
        const n = this.B * this.sizes[l];
        const wgTotal = Math.ceil(n/64);
        const wgX = Math.min(wgTotal, MAX_WG_X);
        const wgY = Math.ceil(wgTotal / MAX_WG_X);
        const pass = enc.beginComputePass();
        pass.setPipeline(this.pipeAdapt);
        pass.setBindGroup(0, this.bgAdapt[l-1]);
        pass.dispatchWorkgroups(wgX, wgY);
        pass.end();
      }
    }
  }
  _runGrad(enc, beta){
    const L = this.L, B = this.B;
    // Grad uniforms per layer transition
    for(let l=0; l<L; l++){
      const ni = this.sizes[l], no = this.sizes[l+1];
      // Determine ρ modes:
      //   FHN: both pre/post are clip (1).
      //   Adaptive: σ for both EXCEPT input layer (l=0) where Win is treated with σ (mode_pre=0 always since adaptive).
      const mode_pre = this.modeFlag;   // 0 for adaptive, 1 for fhn — applies to all states
      const mode_post = this.modeFlag;
      this._writeGradParams(this.bufP_grad[l], {ni, no, B, two_beta: 2*beta, mode_pre, mode_post});
    }
    for(let l=0; l<L; l++){
      const ni = this.sizes[l], no = this.sizes[l+1];
      const pass = enc.beginComputePass();
      pass.setPipeline(this.pipeGW);
      pass.setBindGroup(0, this.bgG[l]);
      pass.dispatchWorkgroups(Math.ceil(ni/8), Math.ceil(no/8));
      pass.setPipeline(this.pipeGB);
      pass.setBindGroup(0, this.bgG[l]);
      pass.dispatchWorkgroups(Math.ceil(no/64));
      pass.end();
    }
  }

  async runFreeAndReadOutputs(iters, dt){
    if(!this.useHPSN){
      if(this._lastTauDt !== dt){ this.setAllTau(dt); this._lastTauDt = dt; }
    }
    this._writeAllUniformsForPass(dt, 0);   // beta=0 → all phases free
    const enc = this.dev.createCommandEncoder();
    this._initAllPhases(enc);
    this._runPhaseRelax(enc, PHASE_F, iters);
    const O = this.O;
    enc.copyBufferToBuffer(this.bufU[PHASE_F][this.L-1], 0, this.rbUoF, 0, this.B*O*4);
    this.dev.queue.submit([enc.finish()]);
    await this.rbUoF.mapAsync(GPUMapMode.READ);
    const r = new Float32Array(this.rbUoF.getMappedRange().slice(0));
    this.rbUoF.unmap();
    return r;
  }
  // Free-phase relax, then read back the activations of an arbitrary internal layer (l in [1..L]).
  async runFreeAndReadLayer(iters, dt, layerIdx){
    if(!this.useHPSN){
      if(this._lastTauDt !== dt){ this.setAllTau(dt); this._lastTauDt = dt; }
    }
    if(layerIdx < 1 || layerIdx > this.L) throw new Error('layerIdx out of range');
    this._writeAllUniformsForPass(dt, 0);
    const enc = this.dev.createCommandEncoder();
    this._initAllPhases(enc);
    this._runPhaseRelax(enc, PHASE_F, iters);
    const size = this.B * this.sizes[layerIdx] * 4;
    const rb = this.dev.createBuffer({size, usage: GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ});
    enc.copyBufferToBuffer(this.bufU[PHASE_F][layerIdx-1], 0, rb, 0, size);
    this.dev.queue.submit([enc.finish()]);
    await rb.mapAsync(GPUMapMode.READ);
    const r = new Float32Array(rb.getMappedRange().slice(0));
    rb.unmap(); rb.destroy?.();
    return r;
  }

  async runOnePassGetGradients({itF=8, itN=5, dt=0.7, beta=0.5, adpC=0.15, adpSteps=3}={}){
    if(this.mode === 'fhn') adpSteps = 0;
    // HPSN backward-compat: when not using heterogeneous-τ, refresh Tau to match runtime dt.
    // When useHPSN=true, the user-set heterogeneous distribution is preserved (Tau not overwritten).
    if(!this.useHPSN){
      if(this._lastTauDt !== dt){ this.setAllTau(dt); this._lastTauDt = dt; }
    }
    this._writeAllUniformsForPass(dt, beta);

    const enc = this.dev.createCommandEncoder();
    this._initAllPhases(enc);
    this._runPhaseRelax(enc, PHASE_F, itF);
    this._runPhaseRelax(enc, PHASE_P, itN);
    this._runPhaseRelax(enc, PHASE_M, itN);
    this._runReward(enc);
    this._runAdaptation(enc, adpC, adpSteps);
    this._runGrad(enc, beta);
    // Readback all gradients (separate buffers per layer) + Uo_free
    for(let l=0; l<this.L; l++){
      enc.copyBufferToBuffer(this.bufGW[l], 0, this.rbGW[l], 0, this.sizes[l+1]*this.sizes[l]*4);
      enc.copyBufferToBuffer(this.bufGB[l], 0, this.rbGB[l], 0, this.sizes[l+1]*4);
    }
    enc.copyBufferToBuffer(this.bufU[PHASE_F][this.L-1], 0, this.rbUoF, 0, this.B*this.O*4);
    this.dev.queue.submit([enc.finish()]);

    const maps = [this.rbUoF.mapAsync(GPUMapMode.READ)];
    for(let l=0; l<this.L; l++){ maps.push(this.rbGW[l].mapAsync(GPUMapMode.READ)); maps.push(this.rbGB[l].mapAsync(GPUMapMode.READ)); }
    await Promise.all(maps);
    const uoF = new Float32Array(this.rbUoF.getMappedRange().slice(0));
    this.rbUoF.unmap();
    const gW = [], gB = [];
    for(let l=0; l<this.L; l++){
      gW.push(new Float32Array(this.rbGW[l].getMappedRange().slice(0)));
      gB.push(new Float32Array(this.rbGB[l].getMappedRange().slice(0)));
      this.rbGW[l].unmap(); this.rbGB[l].unmap();
    }
    return {gW, gB, uoF};
  }
  destroy(){
    const bufs = [this.bufWin, this.bufTgt, this.bufR, this.bufDummyR, this.bufDummyRW1, this.bufDummyRW2, this.bufDummyRW3, this.bufG, this.rbG, this.rbUoF, this.bufP_rew];
    for(const arr of [this.bufW, this.bufB, this.bufGW, this.bufGB, this.rbGW, this.rbGB, this.bufP_init, this.bufP_grad, this.bufP_adapt, this.bufTau]) bufs.push(...arr);
    for(const ph of this.bufU) bufs.push(...ph);
    for(const ph of this.bufP_relax) bufs.push(...ph);
    for(const v of bufs) if(v && v.destroy) try{ v.destroy(); }catch(e){}
  }
}

// Multi-layer AdaGO optimizer.
export class AdaGODeep {
  /**
   * @param {number[]} sizes layer widths; layer l has sizes[l] inputs and sizes[l+1] outputs
   * @param {{OW_K?: number}} [opts] OW_K: the cached orthogonalised direction is recomputed
   *   whenever the batch counter is a multiple of OW_K
   */
  constructor(sizes, {OW_K=8}={}){
    this.sizes = sizes;
    this.L = sizes.length - 1;
    this.MW = [];                               // per-layer weight momentum
    this.MB = [];                               // per-layer bias momentum
    this.vW = new Float64Array(this.L);         // running sum of clipped squared weight-gradient norms
    this.vB = new Float64Array(this.L);         // same, for biases
    this.OW = new Array(this.L).fill(null);     // cached orthogonalised weight directions
    this.OW_K = OW_K;
    this.bc = 0;                                // batch counter, advanced by endBatch()
    for(let l=0; l<this.L; l++){
      const fanIn = sizes[l], fanOut = sizes[l+1];
      this.MW.push(new Float64Array(fanOut*fanIn));
      this.MB.push(new Float64Array(fanOut));
    }
  }

  /**
   * Update layer l in place; the step is ADDED to W and B.
   * @param l   layer index
   * @param W   weights, updated in place
   * @param B   biases, updated in place
   * @param gW  summed weight gradient (divided by bs internally)
   * @param gB  summed bias gradient (divided by bs internally)
   * @param bs  batch size
   * @param lr  learning rate
   */
  step(l, W, B, gW, gB, bs, lr){
    const EPS = 1e-8, MU = 0.9, CLIP = 1.0;
    const fanIn = this.sizes[l], fanOut = this.sizes[l+1];
    const momW = this.MW[l], momB = this.MB[l];

    // Weights: momentum update plus squared norm of the raw (batch-averaged) gradient.
    let sqW = 0;
    for(let k=0; k<W.length; k++){
      const g = gW[k]/bs;
      momW[k] = MU*momW[k] + (1-MU)*g;
      sqW += g*g;
    }
    const normW = Math.sqrt(sqW);
    this.vW[l] += Math.min(sqW, CLIP*CLIP);
    // Refresh the cached direction on first use and on every OW_K-th batch.
    const needsRefresh = !this.OW[l] || this.bc % this.OW_K === 0;
    if(needsRefresh) this.OW[l] = orthCPU(momW, fanOut, fanIn);
    const dir = this.OW[l];
    const stepW = Math.max(EPS, lr*Math.min(normW, CLIP)/(Math.sqrt(this.vW[l])+EPS));
    for(let k=0; k<W.length; k++) W[k] += stepW*dir[k];

    // Biases: same schedule; direction is the momentum divided by the current gradient norm.
    let sqB = 0;
    for(let k=0; k<B.length; k++){
      const g = gB[k]/bs;
      momB[k] = MU*momB[k] + (1-MU)*g;
      sqB += g*g;
    }
    const normB = Math.sqrt(sqB);
    this.vB[l] += Math.min(sqB, CLIP*CLIP);
    const stepB = Math.max(EPS, lr*Math.min(normB, CLIP)/(Math.sqrt(this.vB[l])+EPS));
    for(let k=0; k<B.length; k++){
      B[k] += stepB*(normB>0 ? momB[k]/normB : 0);
    }
  }

  /** Advance the batch counter that schedules direction refreshes. */
  endBatch(){ this.bc++; }
}

// Adam (with optional weight decay → AdamW). EqProp gives an ascent direction, so we += step.
export class Adam {
  /**
   * @param {number[]} sizes layer widths; layer l has sizes[l] inputs and sizes[l+1] outputs
   * @param {object} [opts]
   *   beta1 / beta2: decay rates of the first / second moment estimates;
   *   eps: added to the denominator; weightDecay: decoupled (AdamW-style) decay, weights only.
   */
  constructor(sizes, {beta1=0.9, beta2=0.999, eps=1e-8, weightDecay=0}={}){
    this.sizes=sizes; this.L=sizes.length-1; this.beta1=beta1; this.beta2=beta2; this.eps=eps;
    this.wd = weightDecay;   // AdamW-style decoupled weight decay (applied to W only, not bias)
    this.mW=[]; this.vW=[]; this.mB=[]; this.vB=[];
    this.t=0;                // largest per-layer step count (kept for callers that read it)
    this.tLayer=[];          // number of step() calls seen by each layer
    for(let l=0; l<this.L; l++){
      this.mW.push(new Float64Array(sizes[l+1]*sizes[l]));
      this.vW.push(new Float64Array(sizes[l+1]*sizes[l]));
      this.mB.push(new Float64Array(sizes[l+1]));
      this.vB.push(new Float64Array(sizes[l+1]));
    }
  }

  /**
   * Update layer l in place; the step is ADDED to W and B (ascent direction).
   * @param l   layer index
   * @param W   weights, updated in place
   * @param B   biases, updated in place
   * @param gW  summed weight gradient (divided by bs internally)
   * @param gB  summed bias gradient (divided by bs internally)
   * @param bs  batch size
   * @param lr  learning rate
   */
  step(l, W, B, gW, gB, bs, lr){
    // Bias correction must use the number of updates THIS layer's moments have received.
    // A single counter shared by all layers advances L times per batch and over-corrects
    // every layer after the first.
    const t = (this.tLayer[l] || 0) + 1;
    this.tLayer[l] = t;
    if(t > this.t) this.t = t;
    const b1=this.beta1, b2=this.beta2, eps=this.eps;
    const bc1 = 1 - Math.pow(b1, t), bc2 = 1 - Math.pow(b2, t);
    const wd = this.wd;
    const mW = this.mW[l], vW = this.vW[l], mB = this.mB[l], vB = this.vB[l];
    for(let k=0;k<W.length;k++){
      const g = gW[k]/bs;
      mW[k] = b1*mW[k] + (1-b1)*g;
      vW[k] = b2*vW[k] + (1-b2)*g*g;
      const m_hat = mW[k]/bc1, v_hat = vW[k]/bc2;
      // AdamW: decoupled decay (W *= 1 - lr*wd), then the ascent step + m̂/√v̂
      if(wd > 0) W[k] *= (1 - lr * wd);
      W[k] += lr * m_hat / (Math.sqrt(v_hat) + eps);
    }
    for(let k=0;k<B.length;k++){
      const g = gB[k]/bs;
      mB[k] = b1*mB[k] + (1-b1)*g;
      vB[k] = b2*vB[k] + (1-b2)*g*g;
      const m_hat = mB[k]/bc1, v_hat = vB[k]/bc2;
      B[k] += lr * m_hat / (Math.sqrt(v_hat) + eps);
    }
  }

  /** No per-batch bookkeeping is needed; present for interface parity with the other optimizers. */
  endBatch(){}
}

// Muon optimizer (Keller Jordan's MomentUm Orthogonalized by Newton-schulz).
// Steps in sign-of-singular-values direction. Per-step Newton-Schulz quintic on momentum.
// Uses shape-aware scaling: step = lr * sqrt(max(m,n)/min(m,n)) * Orth(M).
// Coefficients from K. Jordan's original Muon: a, b, c chosen so f(x)=ax+bx³+cx⁵ pushes
// singular values toward 1 (with controlled overshoot).
/**
 * Approximately orthogonalise a row-major m×n matrix with the quintic Newton-Schulz
 * iteration X ← a·X + b·X(XᵀX) + c·X(XᵀX)².
 * The input is first divided by its Frobenius norm (an upper bound on the spectral norm).
 * @param M_in  array-like of m*n numbers, row-major; not modified
 * @param m     number of rows
 * @param n     number of columns
 * @param iters number of Newton-Schulz iterations
 * @returns {Float32Array} row-major m×n result
 */
function muonOrth(M_in, m, n, iters=5){
  const a = 3.4445, b = -4.7750, c = 2.0315;

  // Row-major product A(p×q)·B(q×r) → (p×r); rows of B are skipped where A's entry is 0.
  const matmul = (A, B, p, q, r) => {
    const out = new Float64Array(p*r);
    for(let i=0;i<p;i++){
      for(let k=0;k<q;k++){
        const aik = A[i*q+k];
        if(!aik) continue;
        for(let j=0;j<r;j++) out[i*r+j] += aik*B[k*r+j];
      }
    }
    return out;
  };
  // Row-major transpose of A(p×q) → (q×p).
  const transposed = (A, p, q) => {
    const T = new Float64Array(p*q);
    for(let i=0;i<p;i++) for(let j=0;j<q;j++) T[j*p+i] = A[i*q+j];
    return T;
  };

  // Work on a normalised Float64 copy; the tiny offset keeps an all-zero input finite.
  let X = new Float64Array(M_in.length);
  for(let k=0;k<M_in.length;k++) X[k] = M_in[k];
  let nrm = 0;
  for(const x of X) nrm += x*x;
  nrm = Math.sqrt(nrm) + 1e-30;
  for(let k=0;k<X.length;k++) X[k] /= nrm;

  // Iterate in the tall orientation (R ≥ C) so the Gram matrix XᵀX is the smaller C×C one.
  const wide = (n > m);
  let R = m, C = n;
  if(wide){
    X = transposed(X, m, n);
    R = n; C = m;
  }

  for(let it=0; it<iters; it++){
    const G = matmul(transposed(X, R, C), X, C, R, C);   // XᵀX      (C×C)
    const G2 = matmul(G, G, C, C, C);                    // (XᵀX)²   (C×C)
    const XG = matmul(X, G, R, C, C);                    // X(XᵀX)   (R×C)
    const XG2 = matmul(X, G2, R, C, C);                  // X(XᵀX)²  (R×C)
    const next = new Float64Array(R*C);
    for(let k=0;k<R*C;k++) next[k] = a*X[k] + b*XG[k] + c*XG2[k];
    X = next;
  }

  // Undo the orientation swap so the caller always receives m×n.
  return new Float32Array(wide ? transposed(X, R, C) : X);
}
export class Muon {
  /**
   * @param {number[]} sizes layer widths; layer l has sizes[l] inputs and sizes[l+1] outputs
   * @param {object} [opts]
   *   beta: momentum decay; weightDecay: decoupled decay applied to weights only;
   *   iters: Newton-Schulz iterations passed to muonOrth.
   */
  constructor(sizes, {beta=0.95, weightDecay=0, iters=5}={}){
    this.sizes = sizes;
    this.L = sizes.length - 1;
    this.beta = beta;
    this.wd = weightDecay;
    this.iters = iters;
    this.MW = [];   // per-layer weight momentum
    this.mB = [];   // per-layer bias momentum
    for(let l=0; l<this.L; l++){
      const fanIn = sizes[l], fanOut = sizes[l+1];
      this.MW.push(new Float64Array(fanOut*fanIn));
      this.mB.push(new Float64Array(fanOut));
    }
  }

  /**
   * Update layer l in place; the step is ADDED to W and B (ascent direction).
   * @param l   layer index
   * @param W   weights, updated in place
   * @param B   biases, updated in place
   * @param gW  summed weight gradient (divided by bs internally)
   * @param gB  summed bias gradient (divided by bs internally)
   * @param bs  batch size
   * @param lr  learning rate
   */
  step(l, W, B, gW, gB, bs, lr){
    const {beta, wd} = this;
    const rows = this.sizes[l+1], cols = this.sizes[l];
    const momW = this.MW[l], momB = this.mB[l];

    // Plain momentum accumulation M ← β·M + g; the accumulated buffer itself is orthogonalised.
    for(let k=0; k<W.length; k++){
      momW[k] = beta*momW[k] + gW[k]/bs;
    }
    const dir = muonOrth(momW, rows, cols, this.iters);

    // Shape-aware step size: lr · sqrt(max(rows, cols) / min(rows, cols)).
    const scale = lr * Math.sqrt(Math.max(rows, cols) / Math.min(rows, cols));
    const keep = 1 - lr*wd;   // decoupled weight-decay factor
    for(let k=0; k<W.length; k++){
      if(wd>0) W[k] *= keep;
      W[k] += scale * dir[k];
    }

    // Biases: momentum step without orthogonalisation or decay.
    for(let k=0; k<B.length; k++){
      momB[k] = beta*momB[k] + gB[k]/bs;
      B[k] += lr * momB[k];
    }
  }

  /** No per-batch bookkeeping. */
  endBatch(){}
}

// Lion optimizer (sign of momentum). Often outperforms Adam for some tasks, less memory.
export class Lion {
  /**
   * @param {number[]} sizes layer widths; layer l has sizes[l] inputs and sizes[l+1] outputs
   * @param {object} [opts]
   *   beta1: interpolation weight used to form the update direction;
   *   beta2: decay of the stored momentum;
   *   weightDecay: decoupled decay applied to weights only.
   */
  constructor(sizes, {beta1=0.9, beta2=0.99, weightDecay=0}={}){
    this.sizes=sizes; this.L=sizes.length-1; this.beta1=beta1; this.beta2=beta2; this.wd=weightDecay;
    this.mW=[]; this.mB=[];
    for(let l=0; l<this.L; l++){
      this.mW.push(new Float64Array(sizes[l+1]*sizes[l]));
      this.mB.push(new Float64Array(sizes[l+1]));
    }
  }

  /**
   * Update layer l in place; the signed step is ADDED to W and B (ascent direction).
   * @param l   layer index
   * @param W   weights, updated in place
   * @param B   biases, updated in place
   * @param gW  summed weight gradient (divided by bs internally)
   * @param gB  summed bias gradient (divided by bs internally)
   * @param bs  batch size
   * @param lr  learning rate (every non-zero update has magnitude exactly lr)
   */
  step(l, W, B, gW, gB, bs, lr){
    const b1=this.beta1, b2=this.beta2, wd=this.wd;
    // Three-way sign: a zero direction must produce NO update. Mapping 0 to +1 makes every
    // parameter with zero gradient and zero momentum drift upward by lr on each step.
    const sign = (x) => (x > 0 ? 1 : (x < 0 ? -1 : 0));
    const mW = this.mW[l], mB = this.mB[l];
    for(let k=0;k<W.length;k++){
      const g = gW[k]/bs;
      // update direction: sign(b1*m + (1-b1)*g), using the momentum from BEFORE this step
      const u = sign(b1*mW[k] + (1-b1)*g);
      if(wd>0) W[k] *= (1 - lr*wd);
      W[k] += lr * u;
      // momentum update with b2
      mW[k] = b2*mW[k] + (1-b2)*g;
    }
    for(let k=0;k<B.length;k++){
      const g = gB[k]/bs;
      const u = sign(b1*mB[k] + (1-b1)*g);
      B[k] += lr * u;
      mB[k] = b2*mB[k] + (1-b2)*g;
    }
  }

  /** No per-batch bookkeeping. */
  endBatch(){}
}