// Multi-hidden-layer WebGPU EqProp trainer.
// sizes = [D, H1, H2, ..., Hk, O]
// Modes:
// 0 = adaptive (σ everywhere, u-nudge at output)
// 1 = fhn (clip everywhere, ρ-nudge at output)
// 2 = prism (softplus soft-clip everywhere, ρ-nudge at output)
//
// Key design:
// * ONE generic compute pipeline `pass_layer` updates any layer using uniforms (ni, no, nxt, ...).
// * 3 phases (free / plus / minus): each has its own per-layer state buffer + per-layer bind group.
// * Uniform buffers are pre-written once per pass (one per (phase, layer) pair), so the encoder
// records dispatches that each pick up the right uniforms.
// * Gradient kernels are also generic — per layer transition, compute outer product reduction over batch.
import { orth as orthCPU } from './eqprop_lib.js';
// ----- WGSL: relax (one kernel handles any layer) -----
// Shader module with two compute entry points that share one bind group layout:
//   * pass_layer (workgroup 8x8): gid.x = neuron index i (< p.no), gid.y = batch index b (< p.B).
//     Builds the drive c from the bias, the bottom-up term over Win and, when has_topdown != 0,
//     the top-down term over Uo; then advances Uh[b*no + i] by one step scaled by Tau[i].
//   * init_state (workgroup 64): writes 0.1 into the first p.B * p.no elements of Uh.
// Uniform struct P is 16 four-byte fields (64 bytes); _writeRelaxParams fills the same layout.
// Notes taken from the kernel body:
//   * p.dt is part of the struct, but the integration step uses Tau[i] rather than p.dt.
//   * p.gamma scales the top-down term only when p.mode <= 0.5 (adaptive); for the other modes
//     the top-down sum is added unscaled.
//   * The drive clamp is applied only when clamp_hi > clamp_lo.
//   * Tgt is only read when has_target != 0 and beta != 0.
const WGSL_RELAX = `
struct P {
ni : u32, no : u32, nxt : u32, B : u32,
dt : f32, beta : f32, gamma : f32, mode : f32, // mode: 0=adaptive σ, 1=fhn(clip+cubic), 2=prism (soft-clip via softplus)
has_topdown : u32, has_target : u32, noise_scale : f32, iter_seed : u32,
// sEqProp: noise_scale > 0 injects per-iter per-neuron Gaussian-ish noise into drive.
// Bio-faithful (real synapses are stochastic). At test, run M passes & average outputs.
clamp_lo : f32, clamp_hi : f32, _p_t1 : f32, _p_t2 : f32,
// Tier A — pre-σ drive clamp (algorithmic, uniform-driven). Bounds the pre-activation
// value c before σ(c) to prevent saturation runaway. When clamp_hi <= clamp_lo the
// kernel treats it as DISABLED (no-op). Default in caller = clamp_lo=clamp_hi=0 → disabled.
};
@group(0) @binding(0) var<uniform> p : P;
@group(0) @binding(1) var<storage, read> Win : array<f32>;
@group(0) @binding(2) var<storage, read> W0 : array<f32>; // [no x ni]
@group(0) @binding(3) var<storage, read> b0 : array<f32>; // [no]
@group(0) @binding(4) var<storage, read> W1 : array<f32>; // [nxt x no] (top-down)
@group(0) @binding(5) var<storage, read_write> Uh : array<f32>; // [B*no]
@group(0) @binding(6) var<storage, read> Uo : array<f32>; // [B*nxt]
@group(0) @binding(7) var<storage, read> Tgt : array<f32>; // [B*no]
// HPSN: heterogeneous time constants — per-neuron multiplier on drive integration.
// Tau[i] replaces the global p.dt. Constant Tau[i]=p.dt → behavior identical to scalar-dt EqProp.
// Sampled from Uniform[tau_min, tau_max] → diverse temporal scales like real cortical neurons.
@group(0) @binding(8) var<storage, read> Tau : array<f32>; // [no]
const A1 : f32 = 0.07407407407407407;
const PRISM_K : f32 = 10.0; // sharpness; higher k → harder clip
// PCG-style cheap hash → uniform [0,1). Per-thread, per-iter, per-neuron stochasticity.
fn pcg_hash(seed_in: u32) -> u32 {
var state : u32 = seed_in * 747796405u + 2891336453u;
let word : u32 = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
return (word >> 22u) ^ word;
}
fn unif_noise(b: u32, i: u32, t: u32) -> f32 {
// Triangular distribution (sum of 2 uniforms - 1) ≈ Gaussian-ish, mean=0, variance=1/6.
let h1 = pcg_hash(b * 65537u + i * 257u + t * 31u);
let h2 = pcg_hash(b * 31337u + i * 1031u + t * 17u + 12345u);
let u1 = f32(h1) / 4294967296.0;
let u2 = f32(h2) / 4294967296.0;
return (u1 + u2) - 1.0; // range [-1, 1], roughly triangular
}
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn fhn_rho(u: f32) -> f32 { return clamp(u, 0.0, 1.0); }
fn fhn_rho_p(u: f32) -> f32 { return select(0.0, 1.0, u > 0.0 && u < 1.0); }
fn fhn_f(u: f32) -> f32 { return A1 * u - u*u*u; }
// PRISM activation: ρ(u) = (softplus(k·u) - softplus(k·(u-1))) / k
// Smooth approximation of clip(u,0,1). Derivative: σ(k·u) - σ(k·(u-1)).
// "Prism" = splits drive into a smooth-yet-saturating activation with gradient flow on both sides.
fn softplus(x: f32) -> f32 { return select(x + log(1.0 + exp(-x)), log(1.0 + exp(x)), x <= 0.0); }
fn prism_rho(u: f32) -> f32 {
return (softplus(PRISM_K * u) - softplus(PRISM_K * (u - 1.0))) / PRISM_K;
}
fn prism_rho_p(u: f32) -> f32 {
let k = PRISM_K;
return 1.0/(1.0+exp(-k*u)) - 1.0/(1.0+exp(-k*(u-1.0)));
}
fn rho(u: f32) -> f32 {
if (p.mode > 1.5) { return prism_rho(u); }
if (p.mode > 0.5) { return fhn_rho(u); }
return sg(u);
}
@compute @workgroup_size(8, 8) fn pass_layer(@builtin(global_invocation_id) gid: vec3<u32>) {
let b = gid.y; let i = gid.x;
if (b >= p.B || i >= p.no) { return; }
// bottom-up: c = b0[i] + sum_j W0[i,j] * rho(Win[b,j])
var c : f32 = b0[i];
let row0 = i * p.ni;
let xoff = b * p.ni;
for (var j: u32 = 0u; j < p.ni; j = j + 1u) {
c = c + W0[row0 + j] * rho(Win[xoff + j]);
}
// top-down: gamma * sum_k W1[k,i] * rho(Uo[b,k]) (if next layer exists)
if (p.has_topdown != 0u) {
var td : f32 = 0.0;
let uo_off = b * p.nxt;
for (var k: u32 = 0u; k < p.nxt; k = k + 1u) {
td = td + W1[k * p.no + i] * rho(Uo[uo_off + k]);
}
if (p.mode > 0.5) { c = c + td; } else { c = c + p.gamma * td; }
}
// Tier A — pre-σ drive clamp (algorithmic). Active iff clamp_hi > clamp_lo.
if (p.clamp_hi > p.clamp_lo) {
c = clamp(c, p.clamp_lo, p.clamp_hi);
}
let idx = b * p.no + i;
let u_old = Uh[idx];
// sEqProp noise: per-(b, i, iter_seed) triangular noise added to drive. Zero by default.
let noise = select(0.0, p.noise_scale * unif_noise(b, i, p.iter_seed), p.noise_scale > 0.0);
var u_new : f32;
if (p.mode > 1.5) {
// PRISM: u̇ = ρ'(u)·c + (linear pull) ; ρ-nudge for output. Smooth saturating dynamics.
var drive : f32 = prism_rho_p(u_old) * c - 0.1 * (u_old - 0.5) + noise;
if (p.has_target != 0u && p.beta != 0.0) {
drive = drive + p.beta * (Tgt[idx] - prism_rho(u_old));
}
u_new = u_old + Tau[i] * drive;
u_new = clamp(u_new, -0.3, 1.3);
} else if (p.mode > 0.5) {
// FHN
var drive : f32 = fhn_rho_p(u_old) * c + fhn_f(u_old) + noise;
if (p.has_target != 0u && p.beta != 0.0) {
drive = drive + p.beta * (Tgt[idx] - fhn_rho(u_old));
}
u_new = u_old + Tau[i] * drive;
u_new = clamp(u_new, -0.2, 1.2);
} else {
// Adaptive
var drive : f32 = -u_old + sg(c) + noise;
if (p.has_target != 0u && p.beta != 0.0) {
drive = drive + p.beta * (Tgt[idx] - u_old);
}
u_new = u_old + Tau[i] * drive;
}
Uh[idx] = u_new;
}
// 2D dispatch to handle big buffers (B*no can exceed the 65535 per-dim workgroup limit).
@compute @workgroup_size(64) fn init_state(@builtin(global_invocation_id) gid: vec3<u32>) {
let stride = 65535u * 64u; // workgroups_per_X * threads_per_workgroup
let g = gid.y * stride + gid.x;
let n = p.B * p.no;
if (g < n) { Uh[g] = 0.1; }
}
`;
// ----- WGSL: gradient (one kernel handles any layer transition) -----
// Shader module with two compute entry points sharing one bind group layout:
//   * grad_W (workgroup 8x8): gid.y = post neuron i (< p.no), gid.x = pre neuron j (< p.ni).
//     gW[i*ni + j] = (1/two_beta) * sum_b R[b] * (rho(post+)*rho(pre+) - rho(post-)*rho(pre-)).
//   * grad_B (workgroup 64): gB[i] = (1/two_beta) * sum_b R[b] * (rho(post+) - rho(post-)).
// Uniform struct GP is 8 four-byte fields (32 bytes); _writeGradParams fills the same layout.
// NOTE(review): rho_mode selects by threshold — m > 2.5 identity, m > 1.5 prism soft-clip,
// m > 0.5 hard clip, otherwise σ. The in-shader comment on mode_pre/mode_post lists
// "2=identity", which does not match that mapping (2 selects prism; identity needs m > 2.5).
// NOTE(review): both kernels divide by p.two_beta, so the caller must supply a non-zero value.
// The `c` field of GP is not read by either kernel in this module.
const WGSL_GRAD = `
struct GP {
ni : u32, no : u32, _pad : u32, B : u32,
c : f32, two_beta : f32, mode_pre : f32, mode_post : f32, // mode_pre/post: 0=σ, 1=clip, 2=identity
};
@group(0) @binding(0) var<uniform> p : GP;
@group(0) @binding(1) var<storage, read> UpreP : array<f32>; // [B*ni] - "input" layer state, plus phase
@group(0) @binding(2) var<storage, read> UpreM : array<f32>; // [B*ni] - minus
@group(0) @binding(3) var<storage, read> UpostP: array<f32>; // [B*no]
@group(0) @binding(4) var<storage, read> UpostM: array<f32>; // [B*no]
@group(0) @binding(5) var<storage, read> R : array<f32>; // [B]
@group(0) @binding(6) var<storage, read_write> gW : array<f32>; // [no*ni]
@group(0) @binding(7) var<storage, read_write> gB : array<f32>; // [no]
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
const PRISM_K2 : f32 = 10.0;
fn softplus2(x: f32) -> f32 { return select(x + log(1.0 + exp(-x)), log(1.0 + exp(x)), x <= 0.0); }
fn prism_rho_g(u: f32) -> f32 { return (softplus2(PRISM_K2*u) - softplus2(PRISM_K2*(u-1.0))) / PRISM_K2; }
fn rho_mode(u: f32, m: f32) -> f32 {
if (m > 2.5) { return u; } // identity (linear)
if (m > 1.5) { return prism_rho_g(u); } // prism soft-clip
if (m > 0.5) { return clamp(u, 0.0, 1.0); } // hard-clip (FHN)
return sg(u); // σ (adaptive)
}
@compute @workgroup_size(8, 8) fn grad_W(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.y; let j = gid.x;
if (i >= p.no || j >= p.ni) { return; }
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rh = R[b];
let ip = rho_mode(UpostP[b * p.no + i], p.mode_post);
let im = rho_mode(UpostM[b * p.no + i], p.mode_post);
let jp = rho_mode(UpreP[b * p.ni + j], p.mode_pre);
let jm = rho_mode(UpreM[b * p.ni + j], p.mode_pre);
acc = acc + rh * (ip * jp - im * jm);
}
gW[i * p.ni + j] = acc / p.two_beta;
}
@compute @workgroup_size(64) fn grad_B(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i >= p.no) { return; }
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rh = R[b];
let ip = rho_mode(UpostP[b * p.no + i], p.mode_post);
let im = rho_mode(UpostM[b * p.no + i], p.mode_post);
acc = acc + rh * (ip - im);
}
gB[i] = acc / p.two_beta;
}
`;
// ----- WGSL: reward + adaptation (depends on output layer state) -----
// Shader module with two compute entry points sharing one bind group layout:
//   * compute_reward (workgroup 64): one thread per batch row b (< p.B). Sums the squared
//     error between rho_out(UoF) and Tgt over the p.O outputs, scales it by 1/0.4, caps it
//     at 1 and writes R[b] = 0.1 + 0.9 * r.
//   * adapt_layer (workgroup 64, 2D dispatch): Up, Um <- (1-c)*Up/Um + c*Uf element-wise.
// Uniform struct AP is 8 four-byte fields (32 bytes); _writeAuxParams fills the same layout.
// rho_out mirrors the activation used by the relax and gradient shaders for every mode:
// σ for mode 0, hard clip for mode 1 and the softplus soft-clip for mode 2 (prism).
// Previously mode 2 fell into the hard-clip branch, so the reward was computed with a
// different output activation than the one the prism dynamics and gradients use.
const WGSL_AUX = `
struct AP {
B : u32, O : u32, H_max : u32, n_hidden : u32,
c : f32, mode : f32, _p0 : f32, _p1 : f32,
};
@group(0) @binding(0) var<uniform> p : AP;
@group(0) @binding(1) var<storage, read> UoF : array<f32>;
@group(0) @binding(2) var<storage, read> Tgt : array<f32>;
@group(0) @binding(3) var<storage, read_write> R : array<f32>;
// adaptation buffers (variable size; we pass single layer at a time via separate bind groups)
@group(0) @binding(4) var<storage, read> Uf : array<f32>;
@group(0) @binding(5) var<storage, read_write> Up : array<f32>;
@group(0) @binding(6) var<storage, read_write> Um : array<f32>;
const PRISM_K : f32 = 10.0; // must match the sharpness used by the relax / gradient shaders
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn softplus(x: f32) -> f32 { return select(x + log(1.0 + exp(-x)), log(1.0 + exp(x)), x <= 0.0); }
fn prism_rho(u: f32) -> f32 {
return (softplus(PRISM_K * u) - softplus(PRISM_K * (u - 1.0))) / PRISM_K;
}
fn rho_out(u: f32) -> f32 {
if (p.mode > 1.5) { return prism_rho(u); } // prism soft-clip
if (p.mode > 0.5) { return clamp(u, 0.0, 1.0); } // hard-clip (FHN)
return sg(u); // σ (adaptive)
}
@compute @workgroup_size(64) fn compute_reward(@builtin(global_invocation_id) gid: vec3<u32>) {
let b = gid.x;
if (b >= p.B) { return; }
var loss : f32 = 0.0;
let off = b * p.O;
for (var i: u32 = 0u; i < p.O; i = i + 1u) {
let d = rho_out(UoF[off + i]) - Tgt[off + i];
loss = loss + d * d;
}
let escale : f32 = 0.4;
let rmin : f32 = 0.1;
var r : f32 = loss / escale;
if (r > 1.0) { r = 1.0; }
R[b] = rmin + (1.0 - rmin) * r;
}
// Adjusted Adaptation per layer: Up,Um ← (1-c)*Up + c*Uf. 2D dispatch safe for large buffers.
@compute @workgroup_size(64) fn adapt_layer(@builtin(global_invocation_id) gid: vec3<u32>) {
let stride = 65535u * 64u;
let g = gid.y * stride + gid.x;
if (g >= arrayLength(&Uf)) { return; }
let f = Uf[g];
Up[g] = (1.0 - p.c) * Up[g] + p.c * f;
Um[g] = (1.0 - p.c) * Um[g] + p.c * f;
}
`;
// ----- JS: trainer class -----
/**
 * Acquire a WebGPU adapter and a device configured with the adapter's own limits.
 *
 * For a fixed list of limit names, every limit the adapter reports as a number is
 * forwarded as a required limit, so the device is created with the adapter's maxima
 * instead of the WebGPU defaults.
 *
 * @param {{powerPreference?: string}} [options] - forwarded to `requestAdapter`.
 * @returns {Promise<{adapter: object, dev: object, info: object}>} the adapter, the device
 *   and `adapter.info` (or an empty object when the adapter exposes none).
 * @throws {Error} 'no webgpu' when `navigator.gpu` is missing, 'no adapter' when no
 *   adapter could be obtained.
 */
export async function makeGPUDeep({powerPreference='high-performance'}={}){
  const gpu = navigator.gpu;
  if(!gpu){
    throw new Error('no webgpu');
  }
  const adapter = await gpu.requestAdapter({powerPreference});
  if(!adapter){
    throw new Error('no adapter');
  }
  const limitNames = [
    'maxStorageBuffersPerShaderStage',
    'maxBufferSize',
    'maxStorageBufferBindingSize',
    'maxComputeInvocationsPerWorkgroup',
    'maxComputeWorkgroupSizeX',
    'maxComputeWorkgroupStorageSize',
    'maxBindGroups',
  ];
  // Only numeric limits are requested; anything else the adapter reports is ignored.
  const requiredLimits = {};
  limitNames.forEach((name)=>{
    const value = adapter.limits[name];
    if(typeof value === 'number'){
      requiredLimits[name] = value;
    }
  });
  const dev = await adapter.requestDevice({requiredLimits});
  return {adapter, dev, info: adapter.info||{}};
}
// Phase indices used for the per-phase arrays (state buffers, relax uniforms, bind groups).
// Order matches the per-phase betas written for the output layer: free (0), plus (+beta), minus (-beta).
const PHASE_F = 0, PHASE_P = 1, PHASE_M = 2;
export class GPUTrainerDeep {
// sizes: [D, H1, H2, ..., Hk, O] — len L+1; L = number of weight matrices = sizes.length-1
// mode: 'adaptive' | 'fhn' | 'prism'
// driveClampLo, driveClampHi: Tier A — pre-σ drive clamp; ACTIVE iff hi > lo. Default 0,0 = disabled.
constructor({dev, sizes, B, mode='adaptive', gamma=0.6, hpsnTauMin=0, hpsnTauMax=0, hpsnSeed=42, driveClampLo=0, driveClampHi=0}={}){
this.dev = dev; this.sizes = sizes;
this.L = sizes.length - 1; // number of weight matrices (transitions)
this.B = B; this.O = sizes[sizes.length-1];
this.mode = mode;
this.modeFlag = (mode==='prism') ? 2.0 : (mode==='fhn' ? 1.0 : 0.0);
this.gamma = gamma;
this.hpsnTauMin = hpsnTauMin;
this.hpsnTauMax = hpsnTauMax;
this.hpsnSeed = hpsnSeed;
this.useHPSN = (hpsnTauMax > hpsnTauMin && hpsnTauMin > 0);
this.driveClampLo = driveClampLo;
this.driveClampHi = driveClampHi;
this._build();
// Initialize Tau buffers — either constant=0.7 (backward compat) or per-neuron Uniform[hpsnTauMin, hpsnTauMax].
if(this.useHPSN){
this.setAllTau(0.7, hpsnTauMin, hpsnTauMax, hpsnSeed);
} else {
this.setAllTau(0.7);
}
}
_F32buf(n, usage){
if(!Number.isFinite(n) || n <= 0){
console.error('BAD _F32buf size', {n, sizes:JSON.stringify(this.sizes), sizesArr:this.sizes, B:this.B, L:this.L, S0:this.sizes&&this.sizes[0], S0type:typeof (this.sizes&&this.sizes[0])});
throw new Error('_F32buf called with non-finite n=' + n + ' sizes=' + JSON.stringify(this.sizes));
}
return this.dev.createBuffer({size:Math.max(4,n*4), usage});
}
_build(){
const dev = this.dev, S = this.sizes, B = this.B, L = this.L;
const RW = GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_SRC|GPUBufferUsage.COPY_DST;
const R = GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST|GPUBufferUsage.COPY_SRC;
const UNI = GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST;
const RDS = GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ;
// input + target (shared across phases)
this.bufWin = this._F32buf(B * S[0], R);
this.bufTgt = this._F32buf(B * S[L], R);
// weights & biases (one per transition)
this.bufW = []; this.bufB = [];
for(let l=0; l<L; l++){
this.bufW.push(this._F32buf(S[l+1]*S[l], R));
this.bufB.push(this._F32buf(S[l+1], R));
}
// HPSN: per-layer Tau buffer of size [no]. Default = uniform scalar dt (backward compat).
// User can call setHeterogeneousTau(layer, tauMin, tauMax) to enable HPSN per layer.
this.bufTau = [];
for(let l=0; l<L; l++){ this.bufTau.push(this._F32buf(S[l+1], R)); }
// state buffers: for each of 3 phases, L state buffers (one per non-input layer)
this.bufU = [[],[],[]]; // bufU[phase][l] is layer l+1's state (l=0..L-1, sizes S[1..L])
for(let phase=0; phase<3; phase++){
for(let l=1; l<=L; l++){
this.bufU[phase].push(this._F32buf(B * S[l], RW));
}
}
// reward + dummies (need separate buffers for read-only vs writable slots to avoid aliasing).
this.bufR = this._F32buf(B, RW);
this.bufDummyR = this._F32buf(4, R); // read-only dummy
this.bufDummyRW1 = this._F32buf(4, RW); // writable dummy slot 1
this.bufDummyRW2 = this._F32buf(4, RW); // writable dummy slot 2 (different buffer!)
this.bufDummyRW3 = this._F32buf(4, RW);
// gradient buffer (packed: all gW and gB together)
const gSizes = []; let total=0;
for(let l=0; l<L; l++){ gSizes.push({offW:total, sizW:S[l+1]*S[l], offB:total+S[l+1]*S[l], sizB:S[l+1], total:S[l+1]*S[l]+S[l+1]}); total += gSizes[l].total; }
this.gOff = gSizes; this.gTotal = total;
this.bufG = this._F32buf(total, RW);
this.rbG = dev.createBuffer({size: total*4, usage: RDS});
// readback for output free-phase Uo (for accuracy/loss)
this.rbUoF = dev.createBuffer({size: B * S[L] * 4, usage: RDS});
// ---- pipelines ----
// Relax pipeline (generic)
const modR = dev.createShaderModule({code: WGSL_RELAX});
const sR = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'read-only-storage'}});
const sRW= (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'storage'}});
const uN = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'uniform'}});
this.bglR = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sRW(5), sR(6), sR(7), sR(8)]});
this.plR = dev.createPipelineLayout({bindGroupLayouts:[this.bglR]});
this.pipeRelax = dev.createComputePipeline({layout:this.plR, compute:{module:modR, entryPoint:'pass_layer'}});
this.pipeInit = dev.createComputePipeline({layout:this.plR, compute:{module:modR, entryPoint:'init_state'}});
// Grad pipeline (generic)
const modG = dev.createShaderModule({code: WGSL_GRAD});
this.bglG = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sRW(6), sRW(7)]});
this.plG = dev.createPipelineLayout({bindGroupLayouts:[this.bglG]});
this.pipeGW = dev.createComputePipeline({layout:this.plG, compute:{module:modG, entryPoint:'grad_W'}});
this.pipeGB = dev.createComputePipeline({layout:this.plG, compute:{module:modG, entryPoint:'grad_B'}});
// Aux pipeline (reward + adaptation)
const modA = dev.createShaderModule({code: WGSL_AUX});
this.bglA = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sRW(3), sR(4), sRW(5), sRW(6)]});
this.plA = dev.createPipelineLayout({bindGroupLayouts:[this.bglA]});
this.pipeReward = dev.createComputePipeline({layout:this.plA, compute:{module:modA, entryPoint:'compute_reward'}});
this.pipeAdapt = dev.createComputePipeline({layout:this.plA, compute:{module:modA, entryPoint:'adapt_layer'}});
// ---- uniform buffers (one per (phase, layer)) ----
// Each entry: 48 bytes (12 u32/f32 slots)
this.bufP_relax = [];
for(let phase=0; phase<3; phase++){
this.bufP_relax.push([]);
for(let l=1; l<=L; l++){
this.bufP_relax[phase].push(dev.createBuffer({size: 64, usage: UNI}));
}
}
// init uniform (one per layer, beta=0)
this.bufP_init = [];
for(let l=1; l<=L; l++) this.bufP_init.push(dev.createBuffer({size: 64, usage: UNI}));
// Grad uniform (one per layer transition)
this.bufP_grad = [];
for(let l=0; l<L; l++) this.bufP_grad.push(dev.createBuffer({size: 32, usage: UNI}));
// Aux uniforms: reward + per-layer adaptation
this.bufP_rew = dev.createBuffer({size: 32, usage: UNI});
this.bufP_adapt = []; for(let l=1; l<=L; l++) this.bufP_adapt.push(dev.createBuffer({size: 32, usage: UNI}));
// ---- bind groups ----
// Relax: per (phase, layer)
this.bgR = [[],[],[]];
for(let phase=0; phase<3; phase++){
for(let l=1; l<=L; l++){
// For layer l (1-indexed): Win = state[phase][l-2] if l>1 else bufWin
// W0 = bufW[l-1], b0 = bufB[l-1]
// W1 = bufW[l] (top-down weights to layer l+1), used if l<L
// Uh = state[phase][l-1]
// Uo = state[phase][l] (next layer), used if l<L
// Tgt = bufTgt (used if l==L)
const Win = (l===1) ? this.bufWin : this.bufU[phase][l-2];
const W0 = this.bufW[l-1], b0 = this.bufB[l-1];
const W1 = (l < L) ? this.bufW[l] : this.bufDummyR;
const Uh = this.bufU[phase][l-1];
const Uo = (l < L) ? this.bufU[phase][l] : this.bufDummyR;
const Tgt = this.bufTgt;
this.bgR[phase].push(dev.createBindGroup({layout: this.bglR, entries:[
{binding:0, resource:{buffer:this.bufP_relax[phase][l-1]}},
{binding:1, resource:{buffer:Win}},
{binding:2, resource:{buffer:W0}},
{binding:3, resource:{buffer:b0}},
{binding:4, resource:{buffer:W1}},
{binding:5, resource:{buffer:Uh}},
{binding:6, resource:{buffer:Uo}},
{binding:7, resource:{buffer:Tgt}},
{binding:8, resource:{buffer:this.bufTau[l-1]}},
]}));
}
}
// Init: per (phase, layer) — uses bufP_init (beta=0)
this.bgInit = [[],[],[]];
for(let phase=0; phase<3; phase++){
for(let l=1; l<=L; l++){
const Win = (l===1) ? this.bufWin : this.bufU[phase][l-2];
const W0 = this.bufW[l-1], b0 = this.bufB[l-1];
const W1 = (l < L) ? this.bufW[l] : this.bufDummyR;
const Uh = this.bufU[phase][l-1];
const Uo = (l < L) ? this.bufU[phase][l] : this.bufDummyR;
const Tgt = this.bufTgt;
this.bgInit[phase].push(dev.createBindGroup({layout: this.bglR, entries:[
{binding:0, resource:{buffer:this.bufP_init[l-1]}},
{binding:1, resource:{buffer:Win}},
{binding:2, resource:{buffer:W0}},
{binding:3, resource:{buffer:b0}},
{binding:4, resource:{buffer:W1}},
{binding:5, resource:{buffer:Uh}},
{binding:6, resource:{buffer:Uo}},
{binding:7, resource:{buffer:Tgt}},
{binding:8, resource:{buffer:this.bufTau[l-1]}},
]}));
}
}
// Grad: per layer transition. Note: gW and gB need offsets into bufG.
// Approach: instead of bindings to subranges, we make a SEPARATE buffer per layer for gradients (simpler).
// Pack later by reading back.
this.bufGW = []; this.bufGB = []; this.rbGW = []; this.rbGB = [];
for(let l=0; l<L; l++){
const gw = this._F32buf(S[l+1]*S[l], RW);
const gb = this._F32buf(S[l+1], RW);
this.bufGW.push(gw); this.bufGB.push(gb);
this.rbGW.push(dev.createBuffer({size: S[l+1]*S[l]*4, usage: RDS}));
this.rbGB.push(dev.createBuffer({size: S[l+1]*4, usage: RDS}));
}
// Per layer-transition bind groups for grad:
// Pre layer = state[l-1] (sizes[l]), Post layer = state[l] (sizes[l+1])
// For l=0 (input transition): Pre = bufWin, Post = state[0] (sizes[1])
// For l>0: Pre = state[l-1] (sizes[l]), Post = state[l] (sizes[l+1])
this.bgG = [];
for(let l=0; l<L; l++){
const UpreP = (l===0) ? this.bufWin : this.bufU[PHASE_P][l-1];
const UpreM = (l===0) ? this.bufWin : this.bufU[PHASE_M][l-1];
const UpostP = this.bufU[PHASE_P][l];
const UpostM = this.bufU[PHASE_M][l];
this.bgG.push(dev.createBindGroup({layout: this.bglG, entries:[
{binding:0, resource:{buffer:this.bufP_grad[l]}},
{binding:1, resource:{buffer:UpreP}},
{binding:2, resource:{buffer:UpreM}},
{binding:3, resource:{buffer:UpostP}},
{binding:4, resource:{buffer:UpostM}},
{binding:5, resource:{buffer:this.bufR}},
{binding:6, resource:{buffer:this.bufGW[l]}},
{binding:7, resource:{buffer:this.bufGB[l]}},
]}));
}
// Reward bind group (uses output layer free-phase state)
const Uo_free = this.bufU[PHASE_F][L-1];
this.bgRew = dev.createBindGroup({layout: this.bglA, entries:[
{binding:0, resource:{buffer:this.bufP_rew}},
{binding:1, resource:{buffer:Uo_free}},
{binding:2, resource:{buffer:this.bufTgt}},
{binding:3, resource:{buffer:this.bufR}},
{binding:4, resource:{buffer:this.bufDummyR}},
{binding:5, resource:{buffer:this.bufDummyRW1}},
{binding:6, resource:{buffer:this.bufDummyRW2}},
]});
// Adaptation bind groups (per layer): adapt Up,Um toward Uf
this.bgAdapt = [];
for(let l=1; l<=L; l++){
this.bgAdapt.push(dev.createBindGroup({layout: this.bglA, entries:[
{binding:0, resource:{buffer:this.bufP_adapt[l-1]}},
{binding:1, resource:{buffer:this.bufDummyR}},
{binding:2, resource:{buffer:this.bufDummyR}},
{binding:3, resource:{buffer:this.bufDummyRW3}},
{binding:4, resource:{buffer:this.bufU[PHASE_F][l-1]}},
{binding:5, resource:{buffer:this.bufU[PHASE_P][l-1]}},
{binding:6, resource:{buffer:this.bufU[PHASE_M][l-1]}},
]}));
}
}
_writeRelaxParams(buf, {ni, no, nxt, B, dt, beta, gamma, mode, has_topdown, has_target, noise_scale=0, iter_seed=0, clamp_lo=0, clamp_hi=0}){
const buf32 = new ArrayBuffer(64);
const u32 = new Uint32Array(buf32); const f32 = new Float32Array(buf32);
u32[0]=ni; u32[1]=no; u32[2]=nxt; u32[3]=B;
f32[4]=dt; f32[5]=beta; f32[6]=gamma; f32[7]=mode;
u32[8]=has_topdown; u32[9]=has_target;
f32[10]=noise_scale;
u32[11]=iter_seed;
f32[12]=clamp_lo; f32[13]=clamp_hi; f32[14]=0; f32[15]=0;
this.dev.queue.writeBuffer(buf, 0, buf32);
}
_writeGradParams(buf, {ni, no, B, two_beta, mode_pre, mode_post}){
const buf32 = new ArrayBuffer(32);
const u32 = new Uint32Array(buf32); const f32 = new Float32Array(buf32);
u32[0]=ni; u32[1]=no; u32[2]=0; u32[3]=B;
f32[4]=0; f32[5]=two_beta; f32[6]=mode_pre; f32[7]=mode_post;
this.dev.queue.writeBuffer(buf, 0, buf32);
}
_writeAuxParams(buf, {B, O, c, mode}){
const buf32 = new ArrayBuffer(32);
const u32 = new Uint32Array(buf32); const f32 = new Float32Array(buf32);
u32[0]=B; u32[1]=O; u32[2]=0; u32[3]=0;
f32[4]=c; f32[5]=mode; f32[6]=0; f32[7]=0;
this.dev.queue.writeBuffer(buf, 0, buf32);
}
uploadWeights(W, b){
const q = this.dev.queue;
for(let l=0; l<this.L; l++){
q.writeBuffer(this.bufW[l], 0, W[l].buffer, W[l].byteOffset, W[l].byteLength);
q.writeBuffer(this.bufB[l], 0, b[l].buffer, b[l].byteOffset, b[l].byteLength);
}
}
// HPSN: set per-neuron time constants. If tauMax > tauMin > 0, samples Uniform[tauMin, tauMax].
// Otherwise fills with constant scalarTau (backward-compat with old fixed-dt EqProp).
setTau(layerIdx, scalarTau, tauMin=0, tauMax=0, seed=42){
const no = this.sizes[layerIdx+1];
const arr = new Float32Array(no);
if(tauMax > tauMin && tauMin > 0){
// Deterministic LCG for reproducible per-neuron tau distribution.
let s = (seed>>>0) || 1;
const rng = ()=>{ s = (Math.imul(s, 1664525) + 1013904223) >>> 0; return s/4294967296; };
for(let i=0;i<no;i++) arr[i] = tauMin + rng()*(tauMax - tauMin);
} else {
arr.fill(scalarTau);
}
this.dev.queue.writeBuffer(this.bufTau[layerIdx], 0, arr.buffer, arr.byteOffset, arr.byteLength);
return arr; // return so caller can inspect distribution
}
// Convenience: set ALL layers to the same (scalar or distribution) tau spec.
setAllTau(scalarTau, tauMin=0, tauMax=0, seed=42){
for(let l=0; l<this.L; l++) this.setTau(l, scalarTau, tauMin, tauMax, seed + l*1000);
}
uploadInputs(X, T){
const q = this.dev.queue;
q.writeBuffer(this.bufWin, 0, X.buffer, X.byteOffset, X.byteLength);
q.writeBuffer(this.bufTgt, 0, T.buffer, T.byteOffset, T.byteLength);
}
_initAllPhases(enc){
const L = this.L;
const MAX_WG_X = 65535;
for(let phase=0; phase<3; phase++){
for(let l=1; l<=L; l++){
const n = this.B * this.sizes[l];
const wgTotal = Math.ceil(n/64);
// 2D dispatch when wgTotal exceeds per-dim limit
const wgX = Math.min(wgTotal, MAX_WG_X);
const wgY = Math.ceil(wgTotal / MAX_WG_X);
const pass = enc.beginComputePass();
pass.setPipeline(this.pipeInit);
pass.setBindGroup(0, this.bgInit[phase][l-1]);
pass.dispatchWorkgroups(wgX, wgY);
pass.end();
}
}
}
// CRITICAL: each layer update must be in its OWN compute pass so the GPU sees
// a barrier between writes to layer l's state and reads of that state by layer l+1.
// WebGPU has no implicit synchronization between dispatches within a single pass.
_runPhaseRelax(enc, phase, iters){
const L = this.L, B = this.B;
for(let t=0; t<iters; t++){
for(let l=1; l<=L; l++){
const pass = enc.beginComputePass();
pass.setPipeline(this.pipeRelax);
pass.setBindGroup(0, this.bgR[phase][l-1]);
const no = this.sizes[l];
pass.dispatchWorkgroups(Math.ceil(no/8), Math.ceil(B/8));
pass.end();
}
}
}
// Write all uniform buffers for the pass.
// noiseScale > 0 enables sEqProp; seedBase is added to iteration counter for per-call variation.
_writeAllUniformsForPass(dt, beta, noiseScale=0, seedBase=0){
const S=this.sizes, L=this.L, B=this.B, gam=this.gamma, mf=this.modeFlag;
const phaseBetas = [0, +beta, -beta]; // free, plus, minus
const ns = (typeof this.noiseScale === 'number') ? this.noiseScale : noiseScale;
const sb = (typeof this.noiseSeedBase === 'number') ? this.noiseSeedBase : seedBase;
const cLo = this.driveClampLo || 0;
const cHi = this.driveClampHi || 0;
// Relax uniforms (per phase, per layer). iter_seed is incremented per call below.
for(let phase=0; phase<3; phase++){
for(let l=1; l<=L; l++){
const isOut = (l === L);
const isHid = !isOut;
const ni = S[l-1], no = S[l], nxt = isHid ? S[l+1] : 0;
const phaseBeta = (isOut) ? phaseBetas[phase] : 0;
this._writeRelaxParams(this.bufP_relax[phase][l-1], {
ni, no, nxt, B, dt, beta: phaseBeta, gamma: gam, mode: mf,
has_topdown: isHid ? 1 : 0, has_target: isOut ? 1 : 0,
noise_scale: ns,
iter_seed: (sb + phase * 7919 + (l-1) * 1009) >>> 0,
clamp_lo: cLo, clamp_hi: cHi,
});
}
}
for(let l=1; l<=L; l++){
this._writeRelaxParams(this.bufP_init[l-1], {
ni: S[l-1], no: S[l], nxt: 0, B, dt, beta: 0, gamma: gam, mode: mf, has_topdown: 0, has_target: 0,
noise_scale: 0, iter_seed: 0, // init kernel doesn't use noise
clamp_lo: 0, clamp_hi: 0, // init kernel doesn't run drive — clamp irrelevant
});
}
}
// Tier A — runtime setter for drive clamp. Pass (0,0) to disable.
setDriveClamp(lo, hi){
this.driveClampLo = lo;
this.driveClampHi = hi;
}
// sEqProp: set per-pass noise scale and seed base. Call before runFreeAndReadOutputs / runOnePass.
setSEqPropNoise(noiseScale, seedBase){
this.noiseScale = noiseScale;
this.noiseSeedBase = (seedBase >>> 0) || 0;
}
_runReward(enc){
this._writeAuxParams(this.bufP_rew, {B: this.B, O: this.O, c: 0, mode: this.modeFlag});
const pass = enc.beginComputePass();
pass.setPipeline(this.pipeReward);
pass.setBindGroup(0, this.bgRew);
pass.dispatchWorkgroups(Math.ceil(this.B/64));
pass.end();
}
_runAdaptation(enc, adpC, adpSteps){
if(adpSteps <= 0 || this.mode === 'fhn') return; // skip adaptation in FHN mode
const L = this.L;
for(let l=1; l<=L; l++){
this._writeAuxParams(this.bufP_adapt[l-1], {B: this.B, O: this.O, c: adpC, mode: this.modeFlag});
}
const MAX_WG_X = 65535;
for(let a=0; a<adpSteps; a++){
for(let l=1; l<=L; l++){
const n = this.B * this.sizes[l];
const wgTotal = Math.ceil(n/64);
const wgX = Math.min(wgTotal, MAX_WG_X);
const wgY = Math.ceil(wgTotal / MAX_WG_X);
const pass = enc.beginComputePass();
pass.setPipeline(this.pipeAdapt);
pass.setBindGroup(0, this.bgAdapt[l-1]);
pass.dispatchWorkgroups(wgX, wgY);
pass.end();
}
}
}
_runGrad(enc, beta){
const L = this.L, B = this.B;
// Grad uniforms per layer transition
for(let l=0; l<L; l++){
const ni = this.sizes[l], no = this.sizes[l+1];
// Determine ρ modes:
// FHN: both pre/post are clip (1).
// Adaptive: σ for both EXCEPT input layer (l=0) where Win is treated with σ (mode_pre=0 always since adaptive).
const mode_pre = this.modeFlag; // 0 for adaptive, 1 for fhn — applies to all states
const mode_post = this.modeFlag;
this._writeGradParams(this.bufP_grad[l], {ni, no, B, two_beta: 2*beta, mode_pre, mode_post});
}
for(let l=0; l<L; l++){
const ni = this.sizes[l], no = this.sizes[l+1];
const pass = enc.beginComputePass();
pass.setPipeline(this.pipeGW);
pass.setBindGroup(0, this.bgG[l]);
pass.dispatchWorkgroups(Math.ceil(ni/8), Math.ceil(no/8));
pass.setPipeline(this.pipeGB);
pass.setBindGroup(0, this.bgG[l]);
pass.dispatchWorkgroups(Math.ceil(no/64));
pass.end();
}
}
async runFreeAndReadOutputs(iters, dt){
if(!this.useHPSN){
if(this._lastTauDt !== dt){ this.setAllTau(dt); this._lastTauDt = dt; }
}
this._writeAllUniformsForPass(dt, 0); // beta=0 → all phases free
const enc = this.dev.createCommandEncoder();
this._initAllPhases(enc);
this._runPhaseRelax(enc, PHASE_F, iters);
const O = this.O;
enc.copyBufferToBuffer(this.bufU[PHASE_F][this.L-1], 0, this.rbUoF, 0, this.B*O*4);
this.dev.queue.submit([enc.finish()]);
await this.rbUoF.mapAsync(GPUMapMode.READ);
const r = new Float32Array(this.rbUoF.getMappedRange().slice(0));
this.rbUoF.unmap();
return r;
}
// Free-phase relax, then read back the activations of an arbitrary internal layer (l in [1..L]).
async runFreeAndReadLayer(iters, dt, layerIdx){
if(!this.useHPSN){
if(this._lastTauDt !== dt){ this.setAllTau(dt); this._lastTauDt = dt; }
}
if(layerIdx < 1 || layerIdx > this.L) throw new Error('layerIdx out of range');
this._writeAllUniformsForPass(dt, 0);
const enc = this.dev.createCommandEncoder();
this._initAllPhases(enc);
this._runPhaseRelax(enc, PHASE_F, iters);
const size = this.B * this.sizes[layerIdx] * 4;
const rb = this.dev.createBuffer({size, usage: GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ});
enc.copyBufferToBuffer(this.bufU[PHASE_F][layerIdx-1], 0, rb, 0, size);
this.dev.queue.submit([enc.finish()]);
await rb.mapAsync(GPUMapMode.READ);
const r = new Float32Array(rb.getMappedRange().slice(0));
rb.unmap(); rb.destroy?.();
return r;
}
async runOnePassGetGradients({itF=8, itN=5, dt=0.7, beta=0.5, adpC=0.15, adpSteps=3}={}){
if(this.mode === 'fhn') adpSteps = 0;
// HPSN backward-compat: when not using heterogeneous-τ, refresh Tau to match runtime dt.
// When useHPSN=true, the user-set heterogeneous distribution is preserved (Tau not overwritten).
if(!this.useHPSN){
if(this._lastTauDt !== dt){ this.setAllTau(dt); this._lastTauDt = dt; }
}
this._writeAllUniformsForPass(dt, beta);
const enc = this.dev.createCommandEncoder();
this._initAllPhases(enc);
this._runPhaseRelax(enc, PHASE_F, itF);
this._runPhaseRelax(enc, PHASE_P, itN);
this._runPhaseRelax(enc, PHASE_M, itN);
this._runReward(enc);
this._runAdaptation(enc, adpC, adpSteps);
this._runGrad(enc, beta);
// Readback all gradients (separate buffers per layer) + Uo_free
for(let l=0; l<this.L; l++){
enc.copyBufferToBuffer(this.bufGW[l], 0, this.rbGW[l], 0, this.sizes[l+1]*this.sizes[l]*4);
enc.copyBufferToBuffer(this.bufGB[l], 0, this.rbGB[l], 0, this.sizes[l+1]*4);
}
enc.copyBufferToBuffer(this.bufU[PHASE_F][this.L-1], 0, this.rbUoF, 0, this.B*this.O*4);
this.dev.queue.submit([enc.finish()]);
const maps = [this.rbUoF.mapAsync(GPUMapMode.READ)];
for(let l=0; l<this.L; l++){ maps.push(this.rbGW[l].mapAsync(GPUMapMode.READ)); maps.push(this.rbGB[l].mapAsync(GPUMapMode.READ)); }
await Promise.all(maps);
const uoF = new Float32Array(this.rbUoF.getMappedRange().slice(0));
this.rbUoF.unmap();
const gW = [], gB = [];
for(let l=0; l<this.L; l++){
gW.push(new Float32Array(this.rbGW[l].getMappedRange().slice(0)));
gB.push(new Float32Array(this.rbGB[l].getMappedRange().slice(0)));
this.rbGW[l].unmap(); this.rbGB[l].unmap();
}
return {gW, gB, uoF};
}
destroy(){
const bufs = [this.bufWin, this.bufTgt, this.bufR, this.bufDummyR, this.bufDummyRW1, this.bufDummyRW2, this.bufDummyRW3, this.bufG, this.rbG, this.rbUoF, this.bufP_rew];
for(const arr of [this.bufW, this.bufB, this.bufGW, this.bufGB, this.rbGW, this.rbGB, this.bufP_init, this.bufP_grad, this.bufP_adapt, this.bufTau]) bufs.push(...arr);
for(const ph of this.bufU) bufs.push(...ph);
for(const ph of this.bufP_relax) bufs.push(...ph);
for(const v of bufs) if(v && v.destroy) try{ v.destroy(); }catch(e){}
}
}
// Multi-layer AdaGO optimizer.
// AdaGO-style optimizer state for a stack of dense layers.
//   sizes  — layer widths; transition l maps sizes[l] inputs to sizes[l+1] outputs
//   OW_K   — the orthogonalised weight direction is recomputed only while
//            (batch counter % OW_K) === 0, and reused otherwise
// Per transition it keeps momentum buffers (MW, MB), running sums of clipped
// squared gradient norms (vW, vB) and the cached orthogonalised direction (OW).
export class AdaGODeep {
  constructor(sizes, {OW_K=8}={}){
    const layerCount = sizes.length - 1;
    this.sizes = sizes;
    this.L = layerCount;
    this.MW = [];
    this.MB = [];
    this.vW = new Float64Array(layerCount);
    this.vB = new Float64Array(layerCount);
    this.OW = new Array(layerCount).fill(null);
    this.OW_K = OW_K;
    this.bc = 0; // batches finished so far; advanced by endBatch()
    for(let t = 0; t < layerCount; t++){
      const fanIn = sizes[t];
      const fanOut = sizes[t + 1];
      this.MW.push(new Float64Array(fanOut * fanIn));
      this.MB.push(new Float64Array(fanOut));
    }
  }

  // Updates W and B of transition `l` in place (values are ADDED).
  //   gW, gB — summed gradients; divided by `bs` before use
  //   lr     — base learning rate
  step(l, W, B, gW, gB, bs, lr){
    const EPS = 1e-8;
    const MOMENTUM = 0.9;
    const CLIP = 1.0;
    const fanIn = this.sizes[l];
    const fanOut = this.sizes[l + 1];

    // Weights: momentum update plus squared norm of the averaged gradient.
    let sqNormW = 0;
    for(let k = 0; k < W.length; k++){
      const g = gW[k] / bs;
      this.MW[l][k] = MOMENTUM * this.MW[l][k] + (1 - MOMENTUM) * g;
      sqNormW += g * g;
    }
    const normW = Math.sqrt(sqNormW);
    this.vW[l] += Math.min(sqNormW, CLIP * CLIP);

    // Refresh the orthogonalised direction when missing or on refresh batches.
    const needsRefresh = !this.OW[l] || this.bc % this.OW_K === 0;
    if(needsRefresh) this.OW[l] = orthCPU(this.MW[l], fanOut, fanIn);
    const direction = this.OW[l];

    // Step size is floored at EPS, so it is never exactly zero.
    const stepW = Math.max(EPS, lr * Math.min(normW, CLIP) / (Math.sqrt(this.vW[l]) + EPS));
    for(let k = 0; k < W.length; k++) W[k] += stepW * direction[k];

    // Biases: same accumulation; direction is momentum divided by the
    // gradient norm (zero when that norm is zero).
    let sqNormB = 0;
    for(let k = 0; k < B.length; k++){
      const g = gB[k] / bs;
      this.MB[l][k] = MOMENTUM * this.MB[l][k] + (1 - MOMENTUM) * g;
      sqNormB += g * g;
    }
    const normB = Math.sqrt(sqNormB);
    this.vB[l] += Math.min(sqNormB, CLIP * CLIP);
    const stepB = Math.max(EPS, lr * Math.min(normB, CLIP) / (Math.sqrt(this.vB[l]) + EPS));
    for(let k = 0; k < B.length; k++){
      const unit = normB > 0 ? this.MB[l][k] / normB : 0;
      B[k] += stepB * unit;
    }
  }

  // Marks the end of a batch; drives the OW refresh schedule.
  endBatch(){ this.bc++; }
}
// Adam (with optional weight decay → AdamW). EqProp gives an ascent direction, so we += step.
// Adam / AdamW optimizer state for a stack of dense layers. Updates are ADDED
// to the parameters (the file treats the incoming gradient as an ascent
// direction).
//   sizes       — layer widths; transition l maps sizes[l] -> sizes[l+1]
//   beta1/beta2 — decay rates of the first / second moment estimates
//   eps         — added to sqrt(v̂) in the denominator
//   weightDecay — decoupled decay applied to weights only (biases untouched)
export class Adam {
  constructor(sizes, {beta1=0.9, beta2=0.999, eps=1e-8, weightDecay=0}={}){
    this.sizes=sizes; this.L=sizes.length-1; this.beta1=beta1; this.beta2=beta2; this.eps=eps;
    this.wd = weightDecay; // AdamW-style decoupled weight decay (applied to W only, not bias)
    this.mW=[]; this.vW=[]; this.mB=[]; this.vB=[];
    // `t` is the largest number of updates any single layer has received.
    this.t=0;
    // Bias correction needs the number of updates applied to THIS layer's
    // moments, so each layer carries its own counter.
    this.tLayer=[];
    for(let l=0; l<this.L; l++){
      this.mW.push(new Float64Array(sizes[l+1]*sizes[l]));
      this.vW.push(new Float64Array(sizes[l+1]*sizes[l]));
      this.mB.push(new Float64Array(sizes[l+1]));
      this.vB.push(new Float64Array(sizes[l+1]));
      this.tLayer.push(0);
    }
  }

  // Applies one Adam update to W and B of transition `l`, in place.
  //   gW, gB — summed gradients; divided by `bs` before use
  //   lr     — learning rate
  step(l, W, B, gW, gB, bs, lr){
    // A shared counter bumped on every call would advance once per layer per
    // batch, giving each layer a different (and too large) exponent.
    const t = (this.tLayer[l] || 0) + 1;
    this.tLayer[l] = t;
    if(t > this.t) this.t = t;
    const b1=this.beta1, b2=this.beta2, eps=this.eps;
    const bc1 = 1 - Math.pow(b1, t), bc2 = 1 - Math.pow(b2, t);
    const wd = this.wd;
    for(let k=0;k<W.length;k++){
      const g = gW[k]/bs;
      this.mW[l][k] = b1*this.mW[l][k] + (1-b1)*g;
      this.vW[l][k] = b2*this.vW[l][k] + (1-b2)*g*g;
      const m_hat = this.mW[l][k]/bc1, v_hat = this.vW[l][k]/bc2;
      // AdamW: decoupled decay first, then the (ascent) Adam step.
      if(wd > 0) W[k] *= (1 - lr * wd);
      W[k] += lr * m_hat / (Math.sqrt(v_hat) + eps);
    }
    for(let k=0;k<B.length;k++){
      const g = gB[k]/bs;
      this.mB[l][k] = b1*this.mB[l][k] + (1-b1)*g;
      this.vB[l][k] = b2*this.vB[l][k] + (1-b2)*g*g;
      const m_hat = this.mB[l][k]/bc1, v_hat = this.vB[l][k]/bc2;
      B[k] += lr * m_hat / (Math.sqrt(v_hat) + eps);
    }
  }

  // No per-batch bookkeeping is needed; present for optimizer-interface parity.
  endBatch(){}
}
// Muon optimizer (Keller Jordan's MomentUm Orthogonalized by Newton-schulz).
// Steps in sign-of-singular-values direction. Per-step Newton-Schulz quintic on momentum.
// Uses shape-aware scaling: step = lr * sqrt(max(m,n)/min(m,n)) * Orth(M).
// Coefficients from K. Jordan's original Muon: a, b, c chosen so f(x)=ax+bx³+cx⁵ pushes
// singular values toward 1 (with controlled overshoot).
// Approximately orthogonalises an m×n row-major matrix with Muon's quintic
// Newton-Schulz iteration  X ← a·X + b·X(XᵀX) + c·X(XᵀX)².
//   M_in  — flat row-major matrix with m*n entries (not modified)
//   m, n  — row and column counts
//   iters — number of Newton-Schulz iterations (default 5)
// Returns a new Float32Array of m*n entries in the same row-major layout.
function muonOrth(M_in, m, n, iters=5){
  // helper: matmul A(p,q)·B(q,r) → out(p,r); zero entries of A are skipped
  function mm(A, B, p, q, r){
    const out = new Float64Array(p*r);
    for(let i=0;i<p;i++) for(let k=0;k<q;k++){
      const aik = A[i*q+k];
      if(!aik) continue;
      for(let j=0;j<r;j++) out[i*r+j] += aik*B[k*r+j];
    }
    return out;
  }
  // helper: transpose of a row-major p×q matrix (result is q×p)
  function transpose(A, p, q){
    const T = new Float64Array(p*q);
    for(let i=0;i<p;i++) for(let j=0;j<q;j++) T[j*p+i] = A[i*q+j];
    return T;
  }

  // Work on a Float64 copy scaled by the Frobenius norm, which upper-bounds
  // the spectral norm; the tiny offset avoids 0/0 on an all-zero input.
  let X = new Float64Array(M_in.length);
  for(let k=0;k<M_in.length;k++) X[k] = M_in[k];
  let nrm = 0;
  for(const x of X) nrm += x*x;
  nrm = Math.sqrt(nrm) + 1e-30;
  for(let k=0;k<X.length;k++) X[k] /= nrm;

  // Iterate on the tall orientation (R ≥ C) so the Gram matrix XᵀX is the
  // smaller C×C one; wide inputs are transposed here and back at the end.
  const transp = (n > m);
  let R = m, C = n;
  if(transp){
    X = transpose(X, m, n);
    R = n; C = m;
  }

  const a = 3.4445, b = -4.7750, c = 2.0315;
  for(let it=0; it<iters; it++){
    const G   = mm(transpose(X, R, C), X, C, R, C); // XᵀX      (C×C)
    const G2  = mm(G, G, C, C, C);                  // (XᵀX)²   (C×C)
    const XG  = mm(X, G, R, C, C);                  // X(XᵀX)   (R×C)
    const XG2 = mm(X, G2, R, C, C);                 // X(XᵀX)²  (R×C)
    const Y = new Float64Array(R*C);
    for(let k=0;k<R*C;k++) Y[k] = a*X[k] + b*XG[k] + c*XG2[k];
    X = Y;
  }

  // X is n×m when the input was transposed; flip it back to m×n.
  return new Float32Array(transp ? transpose(X, n, m) : X);
}
// Muon optimizer state for a stack of dense layers. Weight updates follow the
// orthogonalised momentum; biases use plain momentum. Updates are ADDED to the
// parameters (the file treats the incoming gradient as an ascent direction).
//   sizes       — layer widths; transition l maps sizes[l] -> sizes[l+1]
//   beta        — momentum decay
//   weightDecay — decoupled decay applied to weights only
//   iters       — Newton-Schulz iterations forwarded to muonOrth
export class Muon {
  constructor(sizes, {beta=0.95, weightDecay=0, iters=5}={}){
    const layerCount = sizes.length - 1;
    this.sizes = sizes;
    this.L = layerCount;
    this.beta = beta;
    this.wd = weightDecay;
    this.iters = iters;
    this.MW = Array.from({length: layerCount}, (_, t) => new Float64Array(sizes[t + 1] * sizes[t]));
    this.mB = Array.from({length: layerCount}, (_, t) => new Float64Array(sizes[t + 1]));
  }

  // Applies one update to W and B of transition `l`, in place.
  //   gW, gB — summed gradients; divided by `bs` before use
  //   lr     — learning rate
  step(l, W, B, gW, gB, bs, lr){
    const {beta, wd} = this;
    const rows = this.sizes[l + 1];
    const cols = this.sizes[l];
    const momentumW = this.MW[l];
    const momentumB = this.mB[l];

    // Accumulate momentum: M ← beta·M + g (undamped accumulator).
    for(let k = 0; k < W.length; k++){
      momentumW[k] = beta * momentumW[k] + gW[k] / bs;
    }

    // Orthogonalised direction, scaled by lr · sqrt(max(rows,cols)/min(rows,cols)).
    const direction = muonOrth(momentumW, rows, cols, this.iters);
    const aspect = Math.max(rows, cols) / Math.min(rows, cols);
    const scale = lr * Math.sqrt(aspect);
    const keep = 1 - lr * wd;
    for(let k = 0; k < W.length; k++){
      if(wd > 0) W[k] *= keep; // decoupled weight decay, before the step
      W[k] += scale * direction[k];
    }

    // Biases: momentum accumulation followed by a plain lr-scaled step.
    for(let k = 0; k < B.length; k++){
      momentumB[k] = beta * momentumB[k] + gB[k] / bs;
      B[k] += lr * momentumB[k];
    }
  }

  // No per-batch bookkeeping is needed; present for optimizer-interface parity.
  endBatch(){}
}
// Lion optimizer (sign of momentum). Often outperforms Adam for some tasks, less memory.
// Lion optimizer state for a stack of dense layers: every parameter moves by
// ±lr in the direction of the sign of an interpolated momentum. Updates are
// ADDED to the parameters (the file treats the incoming gradient as an ascent
// direction).
//   sizes       — layer widths; transition l maps sizes[l] -> sizes[l+1]
//   beta1       — interpolation weight used to form the update direction
//   beta2       — decay of the stored momentum
//   weightDecay — decoupled decay applied to weights only
export class Lion {
  constructor(sizes, {beta1=0.9, beta2=0.99, weightDecay=0}={}){
    this.sizes=sizes; this.L=sizes.length-1; this.beta1=beta1; this.beta2=beta2; this.wd=weightDecay;
    this.mW=[]; this.mB=[];
    for(let l=0; l<this.L; l++){
      this.mW.push(new Float64Array(sizes[l+1]*sizes[l]));
      this.mB.push(new Float64Array(sizes[l+1]));
    }
  }

  // Applies one Lion update to W and B of transition `l`, in place.
  //   gW, gB — summed gradients; divided by `bs` before use
  //   lr     — learning rate (also the per-parameter step magnitude)
  step(l, W, B, gW, gB, bs, lr){
    const b1=this.beta1, b2=this.beta2, wd=this.wd;
    // Three-way sign: an exactly-zero (or NaN) direction yields no step.
    // Mapping 0 to +1 would make parameters with no gradient signal drift
    // by +lr on every call.
    const sign = (x) => (x > 0 ? 1 : (x < 0 ? -1 : 0));
    for(let k=0;k<W.length;k++){
      const g = gW[k]/bs;
      // update direction: sign(b1*m + (1-b1)*g)
      const u = sign(b1*this.mW[l][k] + (1-b1)*g);
      if(wd>0) W[k] *= (1 - lr*wd);
      W[k] += lr * u;
      // the stored momentum is updated afterwards, with b2
      this.mW[l][k] = b2*this.mW[l][k] + (1-b2)*g;
    }
    for(let k=0;k<B.length;k++){
      const g = gB[k]/bs;
      const u = sign(b1*this.mB[l][k] + (1-b1)*g);
      B[k] += lr * u;
      this.mB[l][k] = b2*this.mB[l][k] + (1-b2)*g;
    }
  }

  // No per-batch bookkeeping is needed; present for optimizer-interface parity.
  endBatch(){}
}
|