#!/usr/bin/env python3
"""Generate a GMM reference case using torch_npu.npu_grouped_matmul.

Dumps x, w (unpermuted ggml-like layout), group_list, and y_ref to binary
files for the C++ POC to load and validate against.
"""
import os

import numpy as np

# Make sure LD_LIBRARY_PATH exists before importing torch_npu
os.environ.setdefault('LD_LIBRARY_PATH', '')
import torch
import torch_npu

torch.npu.set_device(0)
torch.manual_seed(42)

# Toy Qwen3-like MoE shape: D=64 hidden, I=32 intermediate, E=8 experts,
# N*K=16 expanded tokens (small enough to eyeball, large enough to catch
# layout bugs).
D, I, E, TOTAL = 64, 32, 8, 16

# Input x: [TOTAL, D] BF16, the expanded routed tokens.
x = torch.randn(TOTAL, D, dtype=torch.bfloat16).npu()

# Weight w: per-expert [I, D] BF16; gate/up projections have this shape in HF.
# Stack into [E, I, D] and also provide the [E, D, I] permutation for comparison.
w_per_expert = [torch.randn(I, D, dtype=torch.bfloat16).npu() for _ in range(E)]
w_stacked_EID = torch.stack(w_per_expert, dim=0)  # [E, I, D]

# group_list: token counts per expert; must sum to TOTAL.
group_list = torch.tensor([3, 2, 1, 2, 1, 3, 2, 2], dtype=torch.int64).npu()
assert group_list.sum().item() == TOTAL

# Reference: torch_npu.npu_grouped_matmul.
# Per cann-recipes, the weight must be in [E, D, I] so that y = x @ w yields
# y of shape [TOTAL, I]; i.e. each per-expert w is transposed from HF's
# [I, D] to [D, I].
w_transposed = w_stacked_EID.transpose(1, 2).contiguous()  # [E, D, I]

# Call GMM: y = x @ w, result [TOTAL, I]
y_ref = torch_npu.npu_grouped_matmul(
    [x],               # x list
    [w_transposed],    # weight list (transposed)
    group_list=group_list,
    group_type=0,        # group along the M (token) axis
    group_list_type=1,   # group_list holds per-expert counts
    split_item=3,        # single fused input, single fused output
)[0]  # unwrap tensor list

print("x shape:", x.shape, x.dtype)
print("w_stacked_EID shape:", w_stacked_EID.shape, w_stacked_EID.dtype)
print("w_transposed shape:", w_transposed.shape)
print("group_list:", group_list.cpu().tolist())
print("y_ref shape:", y_ref.shape)
print("y_ref[0, 0:4]:", y_ref[0, 0:4].cpu().float().tolist())

# Save binary dumps
out_dir = 'tests/poc_data'
os.makedirs(out_dir, exist_ok=True)


def dump_bf16(name, tensor):
    """Dump a BF16 tensor as raw 2-byte words (native byte order)."""
    path = os.path.join(out_dir, name + '.bin')
    arr = tensor.contiguous().cpu().view(torch.int16).numpy()
    with open(path, 'wb') as f:
        f.write(arr.tobytes())
    print(f"  wrote {name}.bin: {arr.shape} int16 = BF16 raw, {arr.nbytes} bytes")


def dump_int64(name, tensor):
    path = os.path.join(out_dir, name + '.bin')
    arr = tensor.contiguous().cpu().numpy().astype('int64')
    with open(path, 'wb') as f:
        f.write(arr.tobytes())
    print(f"  wrote {name}.bin: {arr.shape} int64, {arr.nbytes} bytes")


# HF-style weight layout (ggml stores a similar one): [E, I, D] is what the
# C++ side gets from safetensors after stacking.
dump_bf16('x', x)
dump_bf16('w_hf_EID', w_stacked_EID)   # C++ input weight (HF layout)
dump_bf16('w_ref_EDI', w_transposed)   # already-permuted reference (for debug)
dump_int64('group_list', group_list)
dump_bf16('y_ref', y_ref)

# Also dump a shapes header
with open(os.path.join(out_dir, 'shapes.txt'), 'w') as f:
    f.write(f"D={D}\nI={I}\nE={E}\nTOTAL={TOTAL}\n")

print("\nAll dumps in:", out_dir)
print("\nTo validate: C++ loads w_hf_EID, permutes [0,2,1] to [E,D,I], NZ-casts, "
      "calls GMMV4, compares output to y_ref.")
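
# --- Optional sanity check (a minimal sketch, not part of the dump contract) ---
# Recompute y with a plain per-expert matmul loop in float32 and compare to
# y_ref. This assumes group_list_type=1 means group_list holds per-expert
# token counts over contiguous token blocks; only BF16-rounding-level
# differences are expected.
x_f32 = x.cpu().float()             # [TOTAL, D]
w_f32 = w_transposed.cpu().float()  # [E, D, I]
y_loop = torch.empty(TOTAL, I, dtype=torch.float32)
offset = 0
for e, cnt in enumerate(group_list.cpu().tolist()):
    y_loop[offset:offset + cnt] = x_f32[offset:offset + cnt] @ w_f32[e]
    offset += cnt
max_err = (y_loop - y_ref.cpu().float()).abs().max().item()
print(f"\nmax |y_loop - y_ref| = {max_err:.6f} (small BF16 rounding error expected)")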
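
# --- Optional round-trip check of the dump format (illustrative only) ---
# Reload y_ref.bin as raw int16 words and reinterpret them as BF16; this
# mirrors what any reader of the *.bin files must do, and assumes native
# (little-endian on the usual targets) byte order.
raw = np.fromfile(os.path.join(out_dir, 'y_ref.bin'), dtype=np.int16)
y_back = torch.from_numpy(raw).view(torch.bfloat16).reshape(TOTAL, I)
assert torch.equal(y_back, y_ref.cpu()), "y_ref.bin round-trip mismatch"
print("y_ref.bin round-trip OK")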