File size: 8,696 Bytes
// test_layer_forward.cpp — integration test for one full transformer layer via engine.h.
//
// Chain: embed_5_tokens → attention_forward (prefill, past=0) → +residual → moe_forward → +residual
// Expected: x + attn_out matches attn_data/final_out.bin (rel < 5e-3), and the final output
// matches moe_data/final_out.bin within BF16 precision (rel < 1e-1).
#include "acl_common.h"
#include "acl_runtime.h"
#include "aclnn_ops.h"
#include "device_weights.h"
#include "engine.h"
#include "model_config.h"
#include "safetensors_loader.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>
// Converts a raw 16-bit BF16 pattern to float.
// The 16 input bits are placed in the upper half of a 32-bit word (lower half
// zero) and that word is reinterpreted as a float via memcpy.
static float bf16_to_float(uint16_t x) {
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    float result;
    std::memcpy(&result, &widened, sizeof(result));
    return result;
}
// Reads the entire file at `p` into memory as raw bytes.
//
// Returns the file contents (an empty vector for an empty file).
// Throws std::runtime_error naming the path when the file cannot be opened,
// its size cannot be determined, or fewer bytes than expected could be read.
// Without these checks a missing file makes tellg() return -1, which becomes a
// huge size_t and fails inside the vector constructor with an unrelated error.
static std::vector<uint8_t> read_file(const std::string& p) {
    // Open positioned at the end so tellg() reports the file size.
    std::ifstream f(p, std::ios::binary | std::ios::ate);
    if (!f) throw std::runtime_error("read_file: cannot open '" + p + "'");

    const std::streamoff end = f.tellg();
    if (end < 0) throw std::runtime_error("read_file: cannot determine size of '" + p + "'");

    std::vector<uint8_t> bytes(static_cast<size_t>(end));
    f.seekg(0);
    if (!bytes.empty() &&
        !f.read(reinterpret_cast<char*>(bytes.data()), static_cast<std::streamsize>(bytes.size()))) {
        throw std::runtime_error("read_file: short read on '" + p + "'");
    }
    return bytes;
}
// Add: out = a + b (BF16).
//
// Issues aclnnAdd on `stream` with a float scalar alpha of 1.0, writing into `out`.
// The workspace size is queried first and a device workspace is allocated only
// when the query reports a non-zero size.
//
// NOTE(review): the workspace buffer goes out of scope when this function
// returns, while the op was submitted on `stream`. Every caller in this file
// syncs right after the call; confirm the workspace does not need to outlive
// the function for other callers.
static void bf16_add(aclrtStream stream, aclTensor* a, aclTensor* b, aclTensor* out) {
    float one = 1.0f;
    aclScalar* alpha = aclCreateScalar(&one, ACL_FLOAT);

    // Phase 1: ask the op how much scratch memory it needs and get its executor.
    uint64_t workspace_bytes = 0;
    aclOpExecutor* executor = nullptr;
    ACLNN_CHECK(aclnnAddGetWorkspaceSize(a, b, alpha, out, &workspace_bytes, &executor));

    // Phase 2: launch with the (possibly empty) workspace.
    DeviceBuffer workspace;
    if (workspace_bytes > 0) {
        workspace.alloc(workspace_bytes);
    }
    ACLNN_CHECK(aclnnAdd(workspace.get(), workspace_bytes, executor, stream));

    aclDestroyScalar(alpha);
}
int main() {
const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
const std::string attn_data = "tests/attn_data";
const std::string moe_data = "tests/moe_data";
ModelConfig cfg;
if (!cfg.load_from_json(model_dir + "/config.json")) return 1;
cfg.compute_derived(1, 0);
const int64_t D = cfg.hidden_size;
const int64_t Hq = cfg.n_heads_per_rank;
const int64_t Hkv = cfg.n_kv_heads_per_rank;
const int64_t Dh = cfg.head_dim;
const int64_t Q_DIM = Hq * Dh;
const int64_t KV_DIM = Hkv * Dh;
const int64_t I = cfg.i_per_rank;
const int64_t E = cfg.num_experts;
const int64_t K = cfg.num_experts_per_tok;
printf("Dims: D=%ld Q_DIM=%ld KV_DIM=%ld I=%ld E=%ld K=%ld\n", D, Q_DIM, KV_DIM, I, E, K);
SafetensorsLoader st;
if (!st.open(model_dir)) return 1;
AclRuntime rt;
rt.init(0);
DeviceWeightsLoader dw(st, cfg);
SharedWeights shared;
LayerAttnWeights attn;
LayerMoEWeights moe;
printf("Loading weights...\n");
if (!dw.load_shared(shared)) return 1;
if (!dw.load_attention(0, attn)) return 1;
if (!dw.load_moe(0, rt.stream(), moe)) return 1;
rt.sync();
// ---- Load 5 prefill tokens ----
auto tok_raw = read_file(attn_data + "/token_ids.bin");
int32_t S = *(int32_t*)tok_raw.data();
std::vector<int32_t> tokens(S);
std::memcpy(tokens.data(), tok_raw.data() + 4, S * 4);
printf("S=%d tokens=[", S); for (auto t : tokens) printf("%d,", t); printf("]\n");
// ---- Embed ----
DeviceBuffer tok_dev(S * 4);
ACL_CHECK(aclrtMemcpy(tok_dev.get(), S * 4, tokens.data(), S * 4, ACL_MEMCPY_HOST_TO_DEVICE));
auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
auto t_embed_w = make_contig_tensor(shared.embed_tokens.get(), ACL_BF16, {cfg.vocab_size, D});
DeviceBuffer x_dev(S * D * 2); // residual / input to layer
auto t_x = make_contig_tensor(x_dev.get(), ACL_BF16, {S, D});
index_select(rt.stream(), t_embed_w.get(), 0, t_tok.get(), t_x.get());
rt.sync();
// ---- Scratch buffers for attention_forward ----
const int64_t MAX_LEN = 128;
DeviceBuffer k_cache(MAX_LEN * KV_DIM * 2), v_cache(MAX_LEN * KV_DIM * 2);
DeviceBuffer q_sc(S * Q_DIM * 2), k_sc(S * KV_DIM * 2), v_sc(S * KV_DIM * 2);
DeviceBuffer xn_sc(S * D * 2), rstd_sc(S * std::max(Hq, Hkv) * 4);
DeviceBuffer rope_sc(1 * S * Hq * Dh * 2);
DeviceBuffer attn_fias_sc(S * Q_DIM * 2); // FIAS output buffer (before o_proj)
DeviceBuffer attn_out_dev(S * D * 2);
// ---- Causal mask (2048x2048) for prefill ----
const int64_t MASK = 2048;
DeviceBuffer mask_dev(MASK * MASK);
std::vector<uint8_t> mh(MASK * MASK, 0);
for (int i = 0; i < MASK; i++)
for (int j = i+1; j < MASK; j++) mh[i*MASK + j] = 1;
ACL_CHECK(aclrtMemcpy(mask_dev.get(), MASK*MASK, mh.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));
auto t_mask = make_contig_tensor(mask_dev.get(), ACL_BOOL, {1, 1, MASK, MASK});
// ---- Attention forward ----
attention_forward(
rt.stream(), cfg, attn,
x_dev.get(), S,
/*past_len=*/0, k_cache.get(), v_cache.get(), MAX_LEN,
t_mask.get(),
q_sc.get(), k_sc.get(), v_sc.get(),
xn_sc.get(), rstd_sc.get(), rope_sc.get(),
attn_fias_sc.get(),
attn_out_dev.get());
rt.sync();
// ---- x1 = x + attn_out (residual) — should match attn_data/final_out.bin ----
DeviceBuffer x1_dev(S * D * 2);
auto t_attn_out = make_contig_tensor(attn_out_dev.get(), ACL_BF16, {S, D});
auto t_x1 = make_contig_tensor(x1_dev.get(), ACL_BF16, {S, D});
bf16_add(rt.stream(), t_x.get(), t_attn_out.get(), t_x1.get());
rt.sync();
auto attn_ref_h = read_file(attn_data + "/final_out.bin");
std::vector<uint16_t> x1_host(S * D);
ACL_CHECK(aclrtMemcpy(x1_host.data(), S*D*2, x1_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
auto* ar = (const uint16_t*)attn_ref_h.data();
double al2d=0, al2r=0, amaxd=0;
for (int i = 0; i < S*D; i++) {
float a = bf16_to_float(x1_host[i]), b = bf16_to_float(ar[i]);
al2d += (a-b)*(a-b); al2r += b*b;
if (std::abs(a-b) > amaxd) amaxd = std::abs(a-b);
}
double arel = std::sqrt(al2d) / (std::sqrt(al2r) + 1e-10);
printf(" [attn] x + attn_out vs attn_data/final_out.bin: rel=%.4e max=%.4f\n", arel, amaxd);
// ---- MoE scratch buffers ----
const int64_t TOTAL = S * K;
DeviceBuffer moe_xn(S * D * 2), moe_rstd(S * 4);
DeviceBuffer moe_logits(S * E * 2);
DeviceBuffer moe_topk_w(S * K * 2), moe_topk_idx(S * K * 4), moe_row_idx(S * K * 4);
DeviceBuffer moe_ex_x(TOTAL * D * 2), moe_ex_ri(TOTAL * 4), moe_tpe(E * 8);
DeviceBuffer moe_fwd(TOTAL * 8);
DeviceBuffer moe_gate(TOTAL * I * 2), moe_up(TOTAL * I * 2), moe_down(TOTAL * D * 2);
DeviceBuffer moe_packed(TOTAL * D * 2), moe_weighted(S * K * D * 2);
DeviceBuffer moe_out_dev(S * D * 2);
moe_forward(rt.stream(), cfg, attn, moe,
x1_dev.get(), S,
moe_xn.get(), moe_rstd.get(),
moe_logits.get(),
moe_topk_w.get(), moe_topk_idx.get(), moe_row_idx.get(),
moe_ex_x.get(), moe_ex_ri.get(), moe_tpe.get(),
moe_fwd.get(),
moe_gate.get(), moe_up.get(), moe_down.get(),
moe_packed.get(), moe_weighted.get(),
moe_out_dev.get());
rt.sync();
// ---- x2 = x1 + moe_out (residual) — should match moe_data/final_out.bin ----
DeviceBuffer x2_dev(S * D * 2);
auto t_moe_out = make_contig_tensor(moe_out_dev.get(), ACL_BF16, {S, D});
auto t_x2 = make_contig_tensor(x2_dev.get(), ACL_BF16, {S, D});
bf16_add(rt.stream(), t_x1.get(), t_moe_out.get(), t_x2.get());
rt.sync();
auto moe_ref_h = read_file(moe_data + "/final_out.bin");
std::vector<uint16_t> x2_host(S * D);
ACL_CHECK(aclrtMemcpy(x2_host.data(), S*D*2, x2_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
auto* mr = (const uint16_t*)moe_ref_h.data();
double ml2d=0, ml2r=0, mmaxd=0;
for (int i = 0; i < S*D; i++) {
float a = bf16_to_float(x2_host[i]), b = bf16_to_float(mr[i]);
ml2d += (a-b)*(a-b); ml2r += b*b;
if (std::abs(a-b) > mmaxd) mmaxd = std::abs(a-b);
}
double mrel = std::sqrt(ml2d) / (std::sqrt(ml2r) + 1e-10);
printf(" [full] x1 + moe_out vs moe_data/final_out.bin: rel=%.4e max=%.4f\n", mrel, mmaxd);
printf(" x2[0, :4]: %.5f %.5f %.5f %.5f\n",
bf16_to_float(x2_host[0]), bf16_to_float(x2_host[1]), bf16_to_float(x2_host[2]), bf16_to_float(x2_host[3]));
printf(" ref[0, :4]: %.5f %.5f %.5f %.5f\n",
bf16_to_float(mr[0]), bf16_to_float(mr[1]), bf16_to_float(mr[2]), bf16_to_float(mr[3]));
// Tolerance: attn chain 5e-3 (tight, only linear ops); full layer 1e-1 (MoE's discrete topk
// routing amplifies BF16 noise — tiny input changes flip expert selection, magnifying output
// delta. End-to-end CLI correctness is validated by test_chat_flow.sh separately.)
bool pass = (arel < 5e-3) && (mrel < 1e-1);
printf("\n%s\n", pass ? "=== test_layer_forward PASS ===" : "=== test_layer_forward FAIL ===");
return pass ? 0 : 1;
}
|