| |
| """Generate a GMM reference case using torch_npu.npu_grouped_matmul. |
| |
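Dumps x, w in both layouts (w_hf_EID: the unpermuted ggml-like [E, I, D] layout;
w_ref_EDI: the transposed [E, D, I] layout passed to the reference op),
group_list, and y_ref as raw binary files, plus shapes.txt with the dimensions,
for the C++ POC to load and validate against.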
| """ |
| import os |
| import sys |
| import numpy as np |
| import struct |
|
|
| |
# Make sure LD_LIBRARY_PATH is present in the environment (an existing value is
# left untouched) before torch / torch_npu are imported.
# NOTE(review): why this has to precede the imports is not visible here - confirm.
os.environ.setdefault("LD_LIBRARY_PATH", "")
import torch
import torch_npu

# Select NPU 0 and fix the RNG seed so the random inputs below are reproducible.
torch.npu.set_device(0)
torch.manual_seed(42)

# Problem dimensions, as used below: x is (TOTAL, D) and each of the E experts
# owns an (I, D) weight matrix.
D = 64
I = 32
E = 8
TOTAL = 16

# Activations: created with torch.randn, then moved to the NPU.
x = torch.randn((TOTAL, D), dtype=torch.bfloat16)
x = x.npu()

# One (I, D) weight matrix per expert, drawn in expert order so the random
# stream is consumed exactly once per expert, then stacked into (E, I, D).
w_per_expert = []
for _ in range(E):
    w_per_expert.append(torch.randn((I, D), dtype=torch.bfloat16).npu())
w_stacked_IDL = torch.stack(w_per_expert)

# One entry per expert; the entries must add up to TOTAL (checked right below).
group_list = torch.tensor(
    [3, 2, 1, 2, 1, 3, 2, 2],
    dtype=torch.int64,
).npu()
assert group_list.sum().item() == TOTAL

# Swap the last two axes, (E, I, D) -> (E, D, I), and materialize the result
# contiguously; this is the weight tensor handed to the reference op.
w_transposed = w_stacked_IDL.permute(0, 2, 1).contiguous()
|
|
| |
# Reference result from torch_npu's grouped matmul. The call result is indexed,
# and its first element is kept as the reference output.
gmm_outputs = torch_npu.npu_grouped_matmul(
    [x],
    [w_transposed],
    group_list=group_list,
    group_type=0,
    group_list_type=1,
    split_item=3,
)
y_ref = gmm_outputs[0]

# Short summary of everything that is about to be dumped.
print("x shape:", x.shape, x.dtype)
print("w_stacked_IDL shape:", w_stacked_IDL.shape, w_stacked_IDL.dtype)
print("w_transposed shape:", w_transposed.shape)
print("group_list:", group_list.cpu().tolist())
print("y_ref shape:", y_ref.shape)
# First four values of output row 0, widened to float32 for readable printing.
first_row_head = y_ref[0, 0:4].cpu().float().tolist()
print("y_ref[0, 0:4]:", first_row_head)
|
|
| |
# Directory that receives every dump. The path is relative, so it is resolved
# against the current working directory; it is created if it does not exist.
out_dir = "tests/poc_data"
os.makedirs(out_dir, exist_ok=True)
|
|
def dump_bf16(name, tensor, directory=None):
    """Write a bfloat16 tensor to ``<directory>/<name>.bin`` as raw 16-bit words.

    The tensor is made contiguous, copied to the CPU, and its elements are
    reinterpreted bit-for-bit as int16 (no numeric conversion). The resulting
    buffer is written without any header, in row-major order and in the
    machine's native byte order.

    Args:
        name: File stem; ``.bin`` is appended.
        tensor: A ``torch.bfloat16`` tensor.
        directory: Destination directory. When omitted, the module-level
            ``out_dir`` is used, which keeps existing two-argument calls
            working unchanged.

    Raises:
        TypeError: If ``tensor`` is not bfloat16. Viewing any other dtype as
            int16 would reinterpret its bytes and silently produce a file that
            does not contain BF16 data.
    """
    if tensor.dtype != torch.bfloat16:
        raise TypeError(
            f"dump_bf16 expects a torch.bfloat16 tensor, got {tensor.dtype}"
        )
    if directory is None:
        directory = out_dir
    path = os.path.join(directory, name + '.bin')
    # view(torch.int16) already yields int16 storage, so no extra astype copy
    # is needed before serializing.
    arr = tensor.contiguous().cpu().view(torch.int16).numpy()
    with open(path, 'wb') as f:
        f.write(arr.tobytes())
    print(f" wrote {name}.bin: {arr.shape} int16 = BF16 raw, {arr.nbytes} bytes")
|
|
def dump_int64(name, tensor, directory=None):
    """Write an integer tensor to ``<directory>/<name>.bin`` as raw int64 values.

    The tensor is made contiguous, copied to the CPU, converted to a NumPy
    array and cast to int64 (so narrower integer inputs are widened). The
    buffer is written without any header, in row-major order and in the
    machine's native byte order.

    Args:
        name: File stem; ``.bin`` is appended.
        tensor: Tensor whose values should be stored as int64.
        directory: Destination directory. When omitted, the module-level
            ``out_dir`` is used, which keeps existing two-argument calls
            working unchanged.
    """
    if directory is None:
        directory = out_dir
    path = os.path.join(directory, name + '.bin')
    arr = tensor.contiguous().cpu().numpy().astype('int64')
    with open(path, 'wb') as f:
        f.write(arr.tobytes())
    print(f" wrote {name}.bin: {arr.shape} int64, {arr.nbytes} bytes")
|
|
| |
# Write every tensor in a fixed order: activations, both weight layouts,
# group_list, then the reference output. Each row is (writer, file stem, tensor).
for _writer, _stem, _payload in (
    (dump_bf16, 'x', x),
    (dump_bf16, 'w_hf_EID', w_stacked_IDL),
    (dump_bf16, 'w_ref_EDI', w_transposed),
    (dump_int64, 'group_list', group_list),
    (dump_bf16, 'y_ref', y_ref),
):
    _writer(_stem, _payload)

# Record the dimensions as KEY=VALUE lines next to the binary dumps.
shapes_path = os.path.join(out_dir, 'shapes.txt')
with open(shapes_path, 'w') as f:
    f.write(f"D={D}\nI={I}\nE={E}\nTOTAL={TOTAL}\n")

print("\nAll dumps in:", out_dir)
print("\nTo validate: C++ loads w_hf_EID, permutes [0,2,1] to [E,D,I], NZ-casts, calls GMMV4, "
      "compares output to y_ref.")
|
|