| |
#include <cstdint>
#include <cstdio>
#include <vector>

#include <aclnnop/aclnn_add.h>

#include "acl_common.h"
|
|
| int main() { |
| ACL_CHECK(aclInit(nullptr)); |
| ACL_CHECK(aclrtSetDevice(0)); |
| aclrtContext ctx; |
| ACL_CHECK(aclrtCreateContext(&ctx, 0)); |
| aclrtStream stream; |
| ACL_CHECK(aclrtCreateStream(&stream)); |
|
|
| |
| const int64_t N = 4; |
| std::vector<float> a_host = {1.0f, 2.0f, 3.0f, 4.0f}; |
| std::vector<float> b_host = {10.0f, 20.0f, 30.0f, 40.0f}; |
| std::vector<float> out_host(N, 0.0f); |
|
|
| DeviceBuffer a_dev(N * sizeof(float)); |
| DeviceBuffer b_dev(N * sizeof(float)); |
| DeviceBuffer out_dev(N * sizeof(float)); |
|
|
| ACL_CHECK(aclrtMemcpy(a_dev.get(), N * 4, a_host.data(), N * 4, ACL_MEMCPY_HOST_TO_DEVICE)); |
| ACL_CHECK(aclrtMemcpy(b_dev.get(), N * 4, b_host.data(), N * 4, ACL_MEMCPY_HOST_TO_DEVICE)); |
|
|
| auto a_t = make_contig_tensor(a_dev.get(), ACL_FLOAT, {N}); |
| auto b_t = make_contig_tensor(b_dev.get(), ACL_FLOAT, {N}); |
| auto out_t = make_contig_tensor(out_dev.get(), ACL_FLOAT, {N}); |
|
|
| |
| float alpha_val = 1.0f; |
| aclScalar* alpha = aclCreateScalar(&alpha_val, ACL_FLOAT); |
|
|
| uint64_t ws_size = 0; |
| aclOpExecutor* executor = nullptr; |
| ACLNN_CHECK(aclnnAddGetWorkspaceSize(a_t.get(), b_t.get(), alpha, out_t.get(), &ws_size, &executor)); |
|
|
| DeviceBuffer ws; |
| if (ws_size > 0) ws.alloc(ws_size); |
| ACLNN_CHECK(aclnnAdd(ws.get(), ws_size, executor, stream)); |
|
|
| ACL_CHECK(aclrtSynchronizeStream(stream)); |
|
|
| ACL_CHECK(aclrtMemcpy(out_host.data(), N * 4, out_dev.get(), N * 4, ACL_MEMCPY_DEVICE_TO_HOST)); |
|
|
| printf("hello_acl: "); |
| for (int i = 0; i < N; i++) printf("%.1f ", out_host[i]); |
| printf("\n"); |
|
|
| bool ok = (out_host[0] == 11.0f && out_host[1] == 22.0f && |
| out_host[2] == 33.0f && out_host[3] == 44.0f); |
| printf(ok ? "PASS\n" : "FAIL\n"); |
|
|
| aclDestroyScalar(alpha); |
| ACL_CHECK(aclrtDestroyStream(stream)); |
| ACL_CHECK(aclrtDestroyContext(ctx)); |
| ACL_CHECK(aclrtResetDevice(0)); |
| aclFinalize(); |
| return ok ? 0 : 1; |
| } |
|
|