// test_rms_norm.cpp — verify aclnnRmsNorm against Python reference. #include "acl_common.h" #include "acl_runtime.h" #include "aclnn_ops.h" #include #include #include #include #include static float bf16_to_float(uint16_t x) { uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f; } static std::vector read_file(const std::string& path) { std::ifstream f(path, std::ios::binary | std::ios::ate); if (!f) { fprintf(stderr, "open %s failed\n", path.c_str()); std::abort(); } size_t sz = f.tellg(); f.seekg(0); std::vector v(sz); f.read((char*)v.data(), sz); return v; } int main() { const std::string data = "tests/rms_norm_data"; // Parse shape.txt int64_t N = 0, D = 0; double eps = 1e-6; { std::ifstream f(data + "/shape.txt"); std::string line; while (std::getline(f, line)) { auto eq = line.find('='); if (eq == std::string::npos) continue; auto k = line.substr(0, eq); auto v = line.substr(eq + 1); if (k == "N") N = std::atoll(v.c_str()); else if (k == "D") D = std::atoll(v.c_str()); else if (k == "eps") eps = std::atof(v.c_str()); } } printf("Shape: N=%ld D=%ld eps=%g\n", N, D, eps); AclRuntime rt; rt.init(0); auto x_host = read_file(data + "/x.bin"); auto gamma_host = read_file(data + "/gamma.bin"); auto y_ref_host = read_file(data + "/y_ref.bin"); DeviceBuffer x_dev(N * D * 2); DeviceBuffer gamma_dev(D * 2); DeviceBuffer y_dev(N * D * 2); DeviceBuffer rstd_dev(N * 4); // F32 ACL_CHECK(aclrtMemcpy(x_dev.get(), x_host.size(), x_host.data(), x_host.size(), ACL_MEMCPY_HOST_TO_DEVICE)); ACL_CHECK(aclrtMemcpy(gamma_dev.get(), gamma_host.size(), gamma_host.data(), gamma_host.size(), ACL_MEMCPY_HOST_TO_DEVICE)); auto t_x = make_contig_tensor(x_dev.get(), ACL_BF16, {N, D}); auto t_gamma = make_contig_tensor(gamma_dev.get(), ACL_BF16, {D}); auto t_y = make_contig_tensor(y_dev.get(), ACL_BF16, {N, D}); auto t_rstd = make_contig_tensor(rstd_dev.get(), ACL_FLOAT, {N}); rms_norm(rt.stream(), t_x.get(), t_gamma.get(), eps, t_y.get(), t_rstd.get()); rt.sync(); std::vector y_cxx(N * D); ACL_CHECK(aclrtMemcpy(y_cxx.data(), N * D * 2, y_dev.get(), N * D * 2, ACL_MEMCPY_DEVICE_TO_HOST)); auto* y_ref = (const uint16_t*)y_ref_host.data(); // Compare double l2_d = 0, l2_r = 0, max_abs = 0; for (int i = 0; i < N * D; i++) { float a = bf16_to_float(y_cxx[i]); float b = bf16_to_float(y_ref[i]); l2_d += (a-b)*(a-b); l2_r += b*b; if (std::abs(a-b) > max_abs) max_abs = std::abs(a-b); } double rel = std::sqrt(l2_d) / (std::sqrt(l2_r) + 1e-10); printf("L2 diff = %.6f, L2 ref = %.6f, relative = %.6e, max abs = %.6f\n", std::sqrt(l2_d), std::sqrt(l2_r), rel, max_abs); printf("y_cxx[0, :8]: "); for (int i = 0; i < 8; i++) printf("%.4f ", bf16_to_float(y_cxx[i])); printf("\ny_ref[0, :8]: "); for (int i = 0; i < 8; i++) printf("%.4f ", bf16_to_float(y_ref[i])); printf("\n"); bool ok = rel < 1e-2; printf("\n%s\n", ok ? "=== test_rms_norm PASS ===" : "=== test_rms_norm FAIL ==="); return ok ? 0 : 1; }