File size: 6,211 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#include "tokenizer.h"

#include <array>
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <memory>
#include <sstream>
#include <unistd.h>

bool Tokenizer::load(const std::string& vocab_bin_path) {
    // Load the id -> raw-bytes vocabulary table from a binary file with layout:
    //   uint32 num_tokens, then per token: uint32 len followed by `len` raw bytes.
    // Returns false on open failure or a truncated/short file (logs only the
    // open failure to stderr, matching the original behavior).
    std::ifstream f(vocab_bin_path, std::ios::binary);
    if (!f) {
        fprintf(stderr, "Tokenizer: cannot open %s\n", vocab_bin_path.c_str());
        return false;
    }
    uint32_t num = 0;
    f.read(reinterpret_cast<char*>(&num), sizeof(num));
    if (!f) return false;
    id_to_bytes_.resize(num);
    for (uint32_t i = 0; i < num; i++) {
        uint32_t len = 0;
        f.read(reinterpret_cast<char*>(&len), sizeof(len));
        if (!f) return false;
        id_to_bytes_[i].resize(len);
        if (len > 0) {
            f.read(id_to_bytes_[i].data(), len);
            // Bug fix: this read was previously unchecked, so a file truncated
            // mid-entry could "load" successfully with a garbage final token.
            if (!f) return false;
        }
    }
    return true;
}

std::string Tokenizer::decode(int id) const {
    // Map one token id to its raw byte string; out-of-range ids decode to "".
    const bool in_range = id >= 0 && static_cast<size_t>(id) < id_to_bytes_.size();
    return in_range ? id_to_bytes_[id] : std::string{};
}

std::string Tokenizer::decode(const std::vector<int>& ids) const {
    std::string out;
    for (int id : ids) out += decode(id);
    return out;
}

std::vector<int> Tokenizer::encode_via_python(const std::string& model_dir,
                                              const std::string& prompt,
                                              bool apply_chat_template) const {
    // Tokenize `prompt` by shelling out to python3, loading the HF tokenizer
    // from `model_dir`. The prompt reaches the subprocess via a temp file on
    // stdin (never interpolated into the command) to avoid shell-escape bugs.
    // Returns the ids, or {} on any subprocess/IO failure.
    std::string cmd;
    // Set QWEN3_PYENV_INIT to override the Python env activation sequence (e.g., "source /opt/my_env/activate && ")
    // Default assumes conda at ~/miniconda3 with env 'qwen3' and Ascend toolkit installed.
    if (const char* init = std::getenv("QWEN3_PYENV_INIT")) {
        cmd += init;
    } else {
        cmd += "source ${HOME}/miniconda3/etc/profile.d/conda.sh 2>/dev/null && ";
        cmd += "conda activate qwen3 2>/dev/null || true; ";
        cmd += "source /usr/local/Ascend/ascend-toolkit/set_env.sh 2>/dev/null || true; ";
    }
    cmd += "python3 -c \"";
    cmd += "import sys, json;";
    cmd += "from transformers import AutoTokenizer;";
    // NOTE(review): model_dir is interpolated into the python source; a quote
    // character in the path would break the command. Paths are assumed clean.
    cmd += "t = AutoTokenizer.from_pretrained('" + model_dir + "');";
    cmd += "p = sys.stdin.read();";
    if (apply_chat_template) {
        cmd += "msg = [{'role': 'user', 'content': p}];";
        cmd += "ids = t.apply_chat_template(msg, add_generation_prompt=True);";
    } else {
        cmd += "ids = t.encode(p);";
    }
    cmd += "print(' '.join(str(i) for i in ids));";
    cmd += "\"";

    // Stage the prompt in a temp file that the subprocess reads as stdin.
    char tmpl[] = "/tmp/lca_prompt_XXXXXX";
    int fd = mkstemp(tmpl);
    if (fd < 0) { perror("mkstemp"); return {}; }
    // Bug fix: write() was previously unchecked — a short or failed write
    // silently tokenized a truncated prompt. Loop until every byte lands.
    size_t off = 0;
    while (off < prompt.size()) {
        ssize_t n = write(fd, prompt.data() + off, prompt.size() - off);
        if (n < 0) {
            if (errno == EINTR) continue;  // interrupted by a signal: retry
            perror("write");
            close(fd);
            unlink(tmpl);
            return {};
        }
        off += (size_t)n;
    }
    close(fd);

    std::string full = cmd + " < " + tmpl + " 2>/dev/null";
    FILE* pipe = popen(full.c_str(), "r");
    if (!pipe) { perror("popen"); unlink(tmpl); return {}; }

    // Slurp the subprocess stdout: one line of space-separated ids.
    std::string out;
    char buf[4096];
    while (size_t n = fread(buf, 1, sizeof(buf), pipe)) out.append(buf, n);
    pclose(pipe);
    unlink(tmpl);

    std::vector<int> ids;
    std::istringstream iss(out);
    int x;
    while (iss >> x) ids.push_back(x);
    return ids;
}

// Escape a string for embedding inside a JSON string literal: backslash-escape
// '"', '\', and common whitespace controls, and emit any other control byte
// (< 0x20) as a \u00XX sequence, per RFC 8259.
// (Fix: the old comment said "Shell-quote" — this is JSON escaping, not shell quoting.)
static std::string json_escape(const std::string& s) {
    std::string out;
    out.reserve(s.size() + 8);
    for (char c : s) {
        switch (c) {
            case '"':  out += "\\\""; break;
            case '\\': out += "\\\\"; break;
            case '\n': out += "\\n";  break;
            case '\r': out += "\\r";  break;
            case '\t': out += "\\t";  break;
            default:
                if ((unsigned char)c < 0x20) {
                    // Remaining control characters have no short escape here.
                    char buf[8];
                    snprintf(buf, sizeof(buf), "\\u%04x", (unsigned char)c);
                    out += buf;
                } else {
                    out += c;
                }
        }
    }
    return out;
}

std::vector<int> Tokenizer::encode_conversation_via_python(
    const std::string& model_dir,
    const std::vector<std::pair<std::string, std::string>>& conversation,
    bool add_generation_prompt) const
{
    // Apply the HF chat template to a (role, content) conversation by shelling
    // out to python3. The messages are serialized to JSON and passed via a
    // temp file on stdin to avoid shell-escape issues. Returns {} on failure.
    std::string json_msgs = "[";
    for (size_t i = 0; i < conversation.size(); i++) {
        if (i > 0) json_msgs += ",";
        json_msgs += "{\"role\":\"" + json_escape(conversation[i].first) + "\",";
        json_msgs += "\"content\":\"" + json_escape(conversation[i].second) + "\"}";
    }
    json_msgs += "]";

    std::string cmd;
    // Set QWEN3_PYENV_INIT to override the Python env activation sequence (e.g., "source /opt/my_env/activate && ")
    // Default assumes conda at ~/miniconda3 with env 'qwen3' and Ascend toolkit installed.
    if (const char* init = std::getenv("QWEN3_PYENV_INIT")) {
        cmd += init;
    } else {
        cmd += "source ${HOME}/miniconda3/etc/profile.d/conda.sh 2>/dev/null && ";
        cmd += "conda activate qwen3 2>/dev/null || true; ";
        cmd += "source /usr/local/Ascend/ascend-toolkit/set_env.sh 2>/dev/null || true; ";
    }
    cmd += "python3 -c \"";
    cmd += "import sys, json;";
    cmd += "from transformers import AutoTokenizer;";
    // NOTE(review): model_dir is interpolated into the python source; a quote
    // character in the path would break the command. Paths are assumed clean.
    cmd += "t = AutoTokenizer.from_pretrained('" + model_dir + "');";
    cmd += "msgs = json.loads(sys.stdin.read());";
    cmd += "ids = t.apply_chat_template(msgs, add_generation_prompt=";
    cmd += add_generation_prompt ? "True" : "False";
    cmd += ");";
    cmd += "print(' '.join(str(i) for i in ids));";
    cmd += "\"";

    char tmpl[] = "/tmp/lca_conv_XXXXXX";
    int fd = mkstemp(tmpl);
    if (fd < 0) { perror("mkstemp"); return {}; }
    // Bug fix: write() was previously unchecked — a short or failed write
    // silently templated a truncated conversation. Loop until all bytes land.
    size_t off = 0;
    while (off < json_msgs.size()) {
        ssize_t n = write(fd, json_msgs.data() + off, json_msgs.size() - off);
        if (n < 0) {
            if (errno == EINTR) continue;  // interrupted by a signal: retry
            perror("write");
            close(fd);
            unlink(tmpl);
            return {};
        }
        off += (size_t)n;
    }
    close(fd);

    std::string full = cmd + " < " + tmpl + " 2>/dev/null";
    FILE* pipe = popen(full.c_str(), "r");
    if (!pipe) { perror("popen"); unlink(tmpl); return {}; }

    // Slurp the subprocess stdout: one line of space-separated ids.
    std::string out;
    char buf[4096];
    while (size_t n = fread(buf, 1, sizeof(buf), pipe)) out.append(buf, n);
    pclose(pipe);
    unlink(tmpl);

    std::vector<int> ids;
    std::istringstream iss(out);
    int x;
    while (iss >> x) ids.push_back(x);
    return ids;
}
}