File size: 1,161 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
// acl_runtime.h — per-rank ACL runtime init/teardown.
#pragma once
#include "acl_common.h"
#include <cstdio>

class AclRuntime {
public:
    AclRuntime() = default;
    ~AclRuntime() { shutdown(); }

    bool init(int device_id) {
        if (initialized_) return true;
        device_id_ = device_id;
        ACL_CHECK(aclInit(nullptr));
        ACL_CHECK(aclrtSetDevice(device_id));
        ACL_CHECK(aclrtCreateContext(&ctx_, device_id));
        ACL_CHECK(aclrtCreateStream(&stream_));
        initialized_ = true;
        return true;
    }

    void shutdown() {
        if (!initialized_) return;
        if (stream_) { aclrtDestroyStream(stream_); stream_ = nullptr; }
        if (ctx_)    { aclrtDestroyContext(ctx_); ctx_ = nullptr; }
        aclrtResetDevice(device_id_);
        aclFinalize();
        initialized_ = false;
    }

    void sync() { if (stream_) ACL_CHECK(aclrtSynchronizeStream(stream_)); }

    aclrtStream stream() const { return stream_; }
    int device_id() const { return device_id_; }

private:
    bool initialized_ = false;
    int device_id_ = 0;
    aclrtContext ctx_ = nullptr;
    aclrtStream  stream_ = nullptr;
};