File size: 4,386 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python3
"""Generate MoE layer forward reference for Qwen3-235B layer 0.

Input: hidden_states from attention output (use attn_data/final_out.bin as input — realistic).
Output: hidden_states after MoE + residual.
"""
import os, json, math, torch, torch_npu
from safetensors.torch import load_file

# Run everything on NPU device 0. The script only produces reference tensors,
# so autograd bookkeeping is switched off globally.
torch.npu.set_device(0)
torch.set_grad_enabled(False)

MODEL_DIR = '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
OUT_DIR   = 'tests/moe_data'
os.makedirs(OUT_DIR, exist_ok=True)

# Read the model config inside a context manager so the file handle is closed
# deterministically instead of whenever the garbage collector gets to it.
with open(os.path.join(MODEL_DIR, 'config.json'), encoding='utf-8') as cfg_file:
    cfg = json.load(cfg_file)
D    = cfg['hidden_size']                  # 4096
I    = cfg['moe_intermediate_size']        # 1536
E    = cfg['num_experts']                  # 128
TK   = cfg['num_experts_per_tok']          # 8
eps  = cfg['rms_norm_eps']
# Falls back to True when the key is absent.
# NOTE(review): confirm this fallback matches the model implementation's own default.
norm_topk = cfg.get('norm_topk_prob', True)

# Use attention output as input (more realistic than random)
with open('tests/attn_data/final_out.bin', 'rb') as attn_file:
    attn_out_raw = attn_file.read()
S = 5
# The file is reinterpreted below as S * D 16-bit words; check the size up
# front so a stale or truncated dump fails with a readable message instead of
# an opaque view/shape error.
expected_nbytes = S * D * 2
if len(attn_out_raw) != expected_nbytes:
    raise ValueError(
        f"tests/attn_data/final_out.bin holds {len(attn_out_raw)} bytes, "
        f"expected {expected_nbytes} (S={S} x D={D} x 2 bytes)")
# bytearray() gives frombuffer a writable copy (bytes objects are read-only).
# The raw words are loaded as int16, shaped to [1, S, D], then bit-cast to
# bfloat16 and moved to the NPU.
x_in = torch.frombuffer(bytearray(attn_out_raw), dtype=torch.int16).view(1, S, D).view(torch.bfloat16).npu()
print(f"x_in: {x_in.shape}")

# Load required weights for layer 0
# The index maps every tensor name to the shard file that stores it.
with open(os.path.join(MODEL_DIR, 'model.safetensors.index.json'), encoding='utf-8') as idx_file:
    idx = json.load(idx_file)
wm = idx['weight_map']

# Post-attention norm weight, router weight, and the three projections of
# every expert in layer 0.
needed = ['model.layers.0.post_attention_layernorm.weight',
          'model.layers.0.mlp.gate.weight']
needed += [f'model.layers.0.mlp.experts.{e}.{p}.weight'
           for e in range(E)
           for p in ('gate_proj', 'up_proj', 'down_proj')]

# Group the wanted names by shard so each shard is opened once and only its
# own names are looked up. A name missing from the index raises KeyError here.
names_by_shard = {}
for n in needed:
    names_by_shard.setdefault(wm[n], []).append(n)

shards = sorted(names_by_shard)
weights = {}
for sh in shards:
    shard_tensors = load_file(os.path.join(MODEL_DIR, sh))
    for n in names_by_shard[sh]:
        # Indexing directly (rather than testing membership) makes a tensor
        # that the index promises but the shard lacks fail right here.
        weights[n] = shard_tensors[n].to('npu')
    # Drop the host-side copy of the shard before loading the next one.
    del shard_tensors
print(f"loaded {len(weights)} tensors from {len(shards)} shards")

# Residual = input
residual = x_in

# Post-attention RmsNorm (the op returns a pair; the second element is unused here)
xn, _ = torch_npu.npu_rms_norm(x_in, weights['model.layers.0.post_attention_layernorm.weight'], epsilon=eps)
xn_flat = xn.view(S, D)  # flatten batch

# Router: logits [S, E]
W_router = weights['model.layers.0.mlp.gate.weight']                # [E, D]
logits = xn_flat @ W_router.t()                                      # [S, E]

# Top-k routing weights
topk_logits, topk_idx = logits.topk(TK, dim=-1)                      # both [S, TK]
if norm_topk:
    # Softmax restricted to the selected logits is mathematically the same as
    # softmax over all experts followed by top-k and renormalisation.
    topk_weights = torch.softmax(topk_logits.float(), dim=-1)        # [S, TK] F32
    topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
else:
    # Without renormalisation the weights must be the selected entries of the
    # softmax over ALL experts. A softmax over only the top-k logits would
    # always sum to 1 and make the flag meaningless.
    topk_weights = torch.softmax(logits.float(), dim=-1).gather(-1, topk_idx)
topk_weights = topk_weights.to(torch.bfloat16)
topk_idx = topk_idx.to(torch.int32)

print(f"topk_idx[0]: {topk_idx[0].cpu().tolist()}")
print(f"topk_weights[0]: {topk_weights[0].cpu().float().tolist()}")

# MoE forward — plain reference path: every token is pushed through each of
# its TK routed experts one at a time (clarity over speed).
out_flat = torch.zeros(S, D, dtype=torch.bfloat16, device='npu')
for row in range(S):
    hidden = xn_flat[row]                                             # [D]
    # Pull this token's expert ids to the host once.
    expert_ids = topk_idx[row].tolist()
    mixed = torch.zeros(D, dtype=torch.bfloat16, device='npu')
    for slot, expert in enumerate(expert_ids):
        gate_w = weights[f'model.layers.0.mlp.experts.{expert}.gate_proj.weight']   # [I, D]
        up_w = weights[f'model.layers.0.mlp.experts.{expert}.up_proj.weight']       # [I, D]
        down_w = weights[f'model.layers.0.mlp.experts.{expert}.down_proj.weight']   # [D, I]
        # Gated MLP: silu(gate projection) * up projection, then project down.
        gated = torch.nn.functional.silu(hidden @ gate_w.t())         # [I]
        lifted = hidden @ up_w.t()                                    # [I]
        projected = (gated * lifted) @ down_w.t()                     # [D]
        # Weight the expert output by its routing weight and accumulate.
        mixed = mixed + topk_weights[row, slot] * projected
    out_flat[row] = mixed

# Restore the batch dimension and add the residual stream back in.
moe_out = out_flat.view(1, S, D)
final_out = residual + moe_out
print(f"final_out[0,0,:4] = {final_out[0,0,:4].float().cpu().tolist()}")

# Dump
def dump(name, t):
    """Write tensor *t* to ``OUT_DIR/<name>.bin`` as raw 16-bit words.

    The tensor is made contiguous, copied to the CPU and reinterpreted as
    int16 without any value conversion, so the file holds the tensor's bit
    patterns exactly. Intended for 2-byte dtypes such as bfloat16.

    Args:
        name: File stem; ``.bin`` is appended.
        t: Tensor to serialise.
    """
    p = os.path.join(OUT_DIR, name + '.bin')
    # bfloat16 tensors cannot be handed to NumPy directly, hence the int16
    # bit-cast before .numpy(); the result is already int16, so no astype().
    a = t.contiguous().cpu().view(torch.int16).numpy()
    # Context manager guarantees the data is flushed and the handle closed.
    with open(p, 'wb') as f:
        f.write(a.tobytes())

# Tensors written through dump(): raw 16-bit words, one .bin file each.
dump('x_in', x_in)
dump('final_out', final_out)
dump('moe_out', moe_out)
dump('router', W_router)
dump('xn', xn)
dump('topk_w', topk_weights)      # [S, TK] BF16 routing weights
dump('out_flat', out_flat)        # [S, D] BF16 — moe contrib before residual

# expert_idx as int32 dump (raw bytes); written separately because dump()
# reinterprets its input as 16-bit words.
topk_idx_bytes = topk_idx.contiguous().cpu().numpy().astype('int32').tobytes()
with open(os.path.join(OUT_DIR, 'topk_idx.bin'), 'wb') as idx_out:
    idx_out.write(topk_idx_bytes)

# Record the dimensions so consumers can reshape the flat .bin files.
with open(os.path.join(OUT_DIR, 'shape.txt'), 'w') as f:
    f.write(f"S={S}\nD={D}\nI={I}\nE={E}\nTK={TK}\n")

print(f"\nDumps in {OUT_DIR}")