| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| #include "runner.h" |
| #include "tokenizer.h" |
|
|
| |
| HcclCtx* runner_hccl_ctx_shim(Runner& r); |
|
|
| #include <algorithm> |
| #include <chrono> |
| #include <cmath> |
| #include <cstdio> |
| #include <cstdlib> |
| #include <cstring> |
| #include <iostream> |
| #include <random> |
| #include <string> |
| #include <vector> |
|
|
// Widen a bfloat16 bit pattern to a float32. bfloat16 is exactly the upper
// 16 bits of an IEEE-754 float32, so placing it in the high half and
// reinterpreting the bits gives the exact value.
static float bf16_to_float(uint16_t x) {
    const uint32_t bits = static_cast<uint32_t>(x) << 16;
    float out;
    std::memcpy(&out, &bits, sizeof out);
    return out;
}
|
|
| |
| |
| |
// Remove a truncated UTF-8 sequence from the end of `s`.
//
// Only the last four bytes are examined. If the final lead byte announces a
// longer sequence than the bytes that follow it, the string is cut just before
// that lead byte; a byte that is neither ASCII, continuation nor a valid lead
// (0xF8..0xFF) is also cut. Strings ending in ASCII or in a complete sequence,
// and strings whose last four bytes are all continuation bytes, are returned
// unchanged.
static std::string utf8_trim_incomplete(const std::string& s) {
    const size_t len = s.size();
    const size_t window = len < 4 ? len : 4;
    size_t trailing = 0;  // continuation bytes seen after the candidate lead byte
    while (trailing < window) {
        const size_t pos = len - 1 - trailing;
        const unsigned char byte = static_cast<unsigned char>(s[pos]);
        if (byte < 0x80) return s;            // ASCII: no sequence pending
        if ((byte & 0xC0) != 0x80) {          // a lead (or invalid) byte
            size_t expected = 0;
            if ((byte & 0xE0) == 0xC0)      expected = 2;
            else if ((byte & 0xF0) == 0xE0) expected = 3;
            else if ((byte & 0xF8) == 0xF0) expected = 4;
            else return s.substr(0, pos);     // not a valid lead byte
            return (trailing + 1 < expected) ? s.substr(0, pos) : s;
        }
        ++trailing;
    }
    return s;
}
|
|
// Command-line configuration for the CLI. The initializers below are the
// values used when the corresponding flag is omitted (see parse_args).
struct Args {
    // Model / tokenizer locations.
    std::string model_dir;
    std::string prompt{"The capital of France is"};
    std::string vocab_path{"tokenizer_data/vocab.bin"};

    // Generation and parallelism.
    int n_predict{100};
    int tp_size{1};
    int tp_rank{0};
    int num_layers{0};   // 0 = take num_hidden_layers from config.json (main does this)
    int max_seq{512};
    int device_id{0};

    // Front-end behaviour.
    bool chat_template{false};
    bool stream{true};
    bool interactive{false};
    bool reset_each_turn{false};
    std::string system_prompt;
    std::string prompt_file;

    // Prompt Lookup Decoding.
    bool pld_enabled{false};
    int pld_k{10};
    int pld_ngram{1};
    bool pld_adaptive{false};
    int pld_min_hist{20};

    // PLD degeneration guard and repeated-token warning.
    bool pld_guard{true};
    int pld_guard_distinct{3};
    int pld_guard_tail{6};
    int pld_loop_warn{8};

    // Sampling.
    float temperature{0.0f};
    int top_k{0};
    float top_p{1.0f};
    uint64_t seed{0};

    // Token ids that end generation.
    std::vector<int> eos_ids{151645, 151643};
};
|
|
| static bool parse_args(int argc, char** argv, Args& a) { |
| for (int i = 1; i < argc; i++) { |
| std::string s = argv[i]; |
| auto next = [&](const char* f)->const char* { |
| if (i + 1 >= argc) { fprintf(stderr, "missing value for %s\n", f); return nullptr; } |
| return argv[++i]; |
| }; |
| if (s == "--model-dir") { auto v = next(s.c_str()); if (!v) return false; a.model_dir = v; } |
| else if (s == "--prompt") { auto v = next(s.c_str()); if (!v) return false; a.prompt = v; } |
| else if (s == "--vocab") { auto v = next(s.c_str()); if (!v) return false; a.vocab_path = v; } |
| else if (s == "--n-predict") { auto v = next(s.c_str()); if (!v) return false; a.n_predict = std::atoi(v); } |
| else if (s == "--tp-size") { auto v = next(s.c_str()); if (!v) return false; a.tp_size = std::atoi(v); } |
| else if (s == "--num-layers") { auto v = next(s.c_str()); if (!v) return false; a.num_layers = std::atoi(v); } |
| else if (s == "--max-seq") { auto v = next(s.c_str()); if (!v) return false; a.max_seq = std::atoi(v); } |
| else if (s == "--device") { auto v = next(s.c_str()); if (!v) return false; a.device_id = std::atoi(v); } |
| else if (s == "--temperature") { auto v = next(s.c_str()); if (!v) return false; a.temperature = (float)std::atof(v); } |
| else if (s == "--top-k") { auto v = next(s.c_str()); if (!v) return false; a.top_k = std::atoi(v); } |
| else if (s == "--top-p") { auto v = next(s.c_str()); if (!v) return false; a.top_p = (float)std::atof(v); } |
| else if (s == "--seed") { auto v = next(s.c_str()); if (!v) return false; a.seed = (uint64_t)std::atoll(v); } |
| else if (s == "--chat") { a.chat_template = true; } |
| else if (s == "--no-stream") { a.stream = false; } |
| else if (s == "--interactive" || s == "-i") { a.interactive = true; } |
| else if (s == "--reset") { a.reset_each_turn = true; } |
| else if (s == "--system") { auto v = next(s.c_str()); if (!v) return false; a.system_prompt = v; } |
| else if (s == "--prompt-file") { auto v = next(s.c_str()); if (!v) return false; a.prompt_file = v; } |
| else if (s == "--pld") { a.pld_enabled = true; } |
| else if (s == "--pld-k") { auto v = next(s.c_str()); if (!v) return false; a.pld_k = std::atoi(v); } |
| else if (s == "--pld-ngram") { auto v = next(s.c_str()); if (!v) return false; a.pld_ngram = std::atoi(v); } |
| else if (s == "--pld-adaptive"){ a.pld_adaptive = true; } |
| else if (s == "--pld-fixed-k") { a.pld_adaptive = false; } |
| else if (s == "--pld-min-hist"){ auto v = next(s.c_str()); if (!v) return false; a.pld_min_hist = std::atoi(v); } |
| else if (s == "--pld-no-guard"){ a.pld_guard = false; } |
| else if (s == "--pld-guard-distinct"){ auto v = next(s.c_str()); if (!v) return false; a.pld_guard_distinct = std::atoi(v); } |
| else if (s == "--pld-guard-tail"){ auto v = next(s.c_str()); if (!v) return false; a.pld_guard_tail = std::atoi(v); } |
| else if (s == "--pld-loop-warn"){ auto v = next(s.c_str()); if (!v) return false; a.pld_loop_warn = std::atoi(v); } |
| else if (s == "--help" || s == "-h") { |
| printf("Usage: %s --model-dir <path> [options]\n", argv[0]); |
| printf(" --prompt \"text\" prompt text (default: \"%s\")\n", a.prompt.c_str()); |
| printf(" --prompt-file FILE read prompt from file (overrides --prompt)\n"); |
| printf(" --n-predict N max tokens to generate (default: %d)\n", a.n_predict); |
| printf(" --tp-size N tensor parallelism (default: 1; or TP_SIZE env)\n"); |
| printf(" --num-layers N limit layers, testing only (default: all)\n"); |
| printf(" --max-seq N KV cache + context cap (default: %d)\n", a.max_seq); |
| printf(" --chat apply Qwen3 chat template\n"); |
| printf(" --system \"text\" system role for chat\n"); |
| printf(" --temperature F 0 = greedy; typical 0.7\n"); |
| printf(" --top-k N 0 = disabled\n"); |
| printf(" --top-p F 1.0 = disabled; typical 0.8\n"); |
| printf(" --seed N 0 = time-based (default)\n"); |
| printf(" --no-stream batch-print final text\n"); |
| printf(" -i, --interactive REPL (multi-turn memory when --chat)\n"); |
| printf(" --reset force stateless REPL (reset each turn)\n"); |
| printf(" --pld enable Prompt Lookup Decoding (greedy only)\n"); |
| printf(" --pld-k N draft window size (default: 4)\n"); |
| printf(" --pld-ngram N match n-gram size (default: 2; multi-level fallback)\n"); |
| printf(" --pld-adaptive adjust K based on recent accept rate\n"); |
| printf(" --pld-min-hist N skip PLD until history >= N tokens (default: 20)\n"); |
| printf(" --pld-no-guard disable degeneration guard (dangerous: can amplify loops)\n"); |
| printf(" --pld-guard-distinct N reject draft with distinct tokens < N (default: 3)\n"); |
| printf(" --pld-guard-tail N reject draft if draft[0] matches all last N hist (default: 6)\n"); |
| printf(" --pld-loop-warn N warn once on N consecutive identical emitted tokens (default: 8)\n"); |
| return false; |
| } |
| else { fprintf(stderr, "unknown arg: %s\n", s.c_str()); return false; } |
| } |
| if (a.model_dir.empty()) { fprintf(stderr, "--model-dir required\n"); return false; } |
| if (const char* r = std::getenv("TP_RANK")) a.tp_rank = std::atoi(r); |
| if (const char* s = std::getenv("TP_SIZE")) a.tp_size = std::atoi(s); |
| return true; |
| } |
|
|
| |
| static int sample_token(const std::vector<uint16_t>& logits_bf16, int64_t V, |
| float temperature, int top_k, float top_p, std::mt19937& rng) { |
| if (temperature <= 0.0f) { |
| int best = 0; |
| float bv = bf16_to_float(logits_bf16[0]); |
| for (int64_t i = 1; i < V; i++) { |
| float v = bf16_to_float(logits_bf16[i]); |
| if (v > bv) { bv = v; best = (int)i; } |
| } |
| return best; |
| } |
|
|
| |
| std::vector<std::pair<float, int>> scored; |
| scored.reserve(V); |
| for (int64_t i = 0; i < V; i++) { |
| scored.emplace_back(bf16_to_float(logits_bf16[i]) / temperature, (int)i); |
| } |
|
|
| |
| if (top_k > 0 && top_k < (int)scored.size()) { |
| std::nth_element(scored.begin(), scored.begin() + top_k, scored.end(), |
| [](const auto& a, const auto& b){ return a.first > b.first; }); |
| scored.resize(top_k); |
| } |
|
|
| |
| std::sort(scored.begin(), scored.end(), |
| [](const auto& a, const auto& b){ return a.first > b.first; }); |
|
|
| |
| float maxv = scored[0].first; |
| double sum = 0; |
| for (auto& p : scored) { p.first = std::exp(p.first - maxv); sum += p.first; } |
| for (auto& p : scored) p.first /= (float)sum; |
|
|
| |
| if (top_p > 0.0f && top_p < 1.0f) { |
| double cum = 0; |
| size_t cutoff = scored.size(); |
| for (size_t i = 0; i < scored.size(); i++) { |
| cum += scored[i].first; |
| if (cum >= top_p) { cutoff = i + 1; break; } |
| } |
| scored.resize(cutoff); |
| |
| double s = 0; for (auto& p : scored) s += p.first; |
| for (auto& p : scored) p.first /= (float)s; |
| } |
|
|
| |
| std::uniform_real_distribution<float> U(0.0f, 1.0f); |
| float r = U(rng), acc = 0.0f; |
| for (auto& p : scored) { |
| acc += p.first; |
| if (r <= acc) return p.second; |
| } |
| return scored.back().second; |
| } |
|
|
| |
| |
| |
| static bool broadcast_token_ids(Runner& runner, std::vector<int32_t>& ids, |
| int64_t max_seq, bool is_master) { |
| const ModelConfig& cfg = runner.cfg(); |
| if (cfg.tp_size <= 1) return true; |
|
|
| |
| DeviceBuffer cnt_dev(4); |
| int32_t cnt = is_master ? (int32_t)ids.size() : 0; |
| ACL_CHECK(aclrtMemcpy(cnt_dev.get(), 4, &cnt, 4, ACL_MEMCPY_HOST_TO_DEVICE)); |
| |
| |
| |
| |
| |
| extern HcclCtx* runner_hccl_ctx_shim(Runner& r); |
| HcclCtx* ctx = runner_hccl_ctx_shim(runner); |
| if (!ctx) return false; |
|
|
| if (!hccl_broadcast(*ctx, cnt_dev.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream())) return false; |
| ACL_CHECK(aclrtSynchronizeStream(runner.stream())); |
| ACL_CHECK(aclrtMemcpy(&cnt, 4, cnt_dev.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST)); |
| if (cnt <= 0 || cnt > (int32_t)max_seq) { |
| fprintf(stderr, "[rank %d] broadcast: bad count %d\n", cfg.tp_rank, cnt); |
| return false; |
| } |
|
|
| |
| DeviceBuffer ids_dev(cnt * 4); |
| if (is_master) { |
| ACL_CHECK(aclrtMemcpy(ids_dev.get(), cnt*4, ids.data(), cnt*4, ACL_MEMCPY_HOST_TO_DEVICE)); |
| } |
| if (!hccl_broadcast(*ctx, ids_dev.get(), cnt, HCCL_DATA_TYPE_INT32, 0, runner.stream())) return false; |
| ACL_CHECK(aclrtSynchronizeStream(runner.stream())); |
| if (!is_master) { |
| ids.resize(cnt); |
| ACL_CHECK(aclrtMemcpy(ids.data(), cnt*4, ids_dev.get(), cnt*4, ACL_MEMCPY_DEVICE_TO_HOST)); |
| } |
| return true; |
| } |
|
|
| |
// Timings and token counts collected for one prompt/response turn.
struct TurnStats {
    double prefill_ms{0.0};  // wall time of the prefill call
    double decode_ms{0.0};   // wall time of the decode loop
    int n_prompt{0};         // tokens fed to prefill
    int decoded{0};          // tokens emitted by the decode loop (excludes the prefill-sampled token)
    bool hit_eos{false};     // generation stopped on an EOS id
};
| static TurnStats run_turn(Runner& runner, Tokenizer& tokenizer, const Args& args, |
| const std::string& prompt, std::mt19937& rng, bool is_master) { |
| TurnStats st; |
|
|
| |
| std::vector<int32_t> input_ids; |
| if (is_master) { |
| auto raw = tokenizer.encode_via_python(args.model_dir, prompt, args.chat_template); |
| if (raw.empty()) return st; |
| input_ids.reserve(raw.size()); |
| for (int v : raw) input_ids.push_back((int32_t)v); |
| } |
| if (args.tp_size > 1) { |
| if (!broadcast_token_ids(runner, input_ids, args.max_seq, is_master)) return st; |
| } |
| if (input_ids.empty()) return st; |
|
|
| const int64_t V = runner.cfg().vocab_size; |
| std::vector<uint16_t> logits_h(V); |
| auto load_logits = [&](DeviceBuffer& buf) { |
| ACL_CHECK(aclrtMemcpy(logits_h.data(), V*2, buf.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| }; |
| auto is_eos = [&](int id) { |
| for (int e : args.eos_ids) if (id == e) return true; |
| return false; |
| }; |
|
|
| |
| st.n_prompt = (int)input_ids.size(); |
| auto t0 = std::chrono::steady_clock::now(); |
| DeviceBuffer logits; |
| if (!runner.prefill(input_ids.data(), (int64_t)input_ids.size(), logits)) return st; |
| auto t1 = std::chrono::steady_clock::now(); |
| st.prefill_ms = std::chrono::duration<double, std::milli>(t1 - t0).count(); |
|
|
| load_logits(logits); |
| int next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng); |
|
|
| if (is_master && args.stream) { |
| if (!args.chat_template) printf("%s", prompt.c_str()); |
| printf("%s", tokenizer.decode(next_id).c_str()); |
| fflush(stdout); |
| } |
|
|
| std::vector<int> generated = { next_id }; |
| st.hit_eos = is_eos(next_id); |
|
|
| |
| |
| std::vector<int32_t> hist; |
| hist.reserve(input_ids.size() + args.n_predict + 16); |
| for (auto x : input_ids) hist.push_back(x); |
| hist.push_back(next_id); |
|
|
| |
| |
| auto lookup_one = [&](int ngram, int K) -> std::vector<int32_t> { |
| int hs = (int)hist.size(); |
| if (hs < ngram + 1 || K <= 0) return {}; |
| for (int start = hs - ngram - 1; start >= 0; start--) { |
| bool match = true; |
| for (int k = 0; k < ngram; k++) { |
| if (hist[start + k] != hist[hs - ngram + k]) { match = false; break; } |
| } |
| if (match) { |
| int after = start + ngram; |
| std::vector<int32_t> d; |
| for (int k = 0; k < K && after + k < hs; k++) { |
| d.push_back(hist[after + k]); |
| if (is_eos(hist[after + k])) break; |
| } |
| if (!d.empty()) return d; |
| } |
| } |
| return {}; |
| }; |
| |
| auto lookup_draft = [&](int ngram, int K) -> std::vector<int32_t> { |
| for (int n = ngram; n >= 1; n--) { |
| auto d = lookup_one(n, K); |
| if (!d.empty()) return d; |
| } |
| return {}; |
| }; |
|
|
| |
| |
| |
| auto draft_degenerate = [&](const std::vector<int32_t>& d) -> const char* { |
| if (!args.pld_guard || d.empty()) return nullptr; |
| |
| |
| if ((int)d.size() >= 3) { |
| int distinct = 0; |
| for (int i = 0; i < (int)d.size(); i++) { |
| bool seen = false; |
| for (int j = 0; j < i; j++) { if (d[j] == d[i]) { seen = true; break; } } |
| if (!seen) distinct++; |
| } |
| if (distinct < args.pld_guard_distinct) return "low-distinct"; |
| } |
| |
| |
| int tail_n = std::min(args.pld_guard_tail, (int)hist.size()); |
| if (tail_n >= 3) { |
| int matches = 0; |
| for (int i = (int)hist.size() - tail_n; i < (int)hist.size(); i++) { |
| if (hist[i] == d[0]) matches++; |
| } |
| if (matches == tail_n) return "tail-echo"; |
| } |
| return nullptr; |
| }; |
|
|
| |
| auto t2 = std::chrono::steady_clock::now(); |
| int pld_verifies = 0, pld_accepted = 0; |
| int pld_rej_lowdist = 0, pld_rej_tailecho = 0; |
| bool loop_warned = false; |
|
|
| |
| const int ADAPT_WINDOW = 8; |
| std::vector<int> recent_accepts; |
| int current_k = args.pld_k; |
| bool pld_disabled_adapt = false; |
|
|
| while (st.decoded < args.n_predict - 1 && !st.hit_eos) { |
| |
| |
| |
| if (args.pld_adaptive && (int)recent_accepts.size() >= ADAPT_WINDOW) { |
| double avg = 0; |
| for (int a : recent_accepts) avg += a; |
| avg /= recent_accepts.size(); |
| |
| current_k = std::max(4, std::min(12, (int)std::round(2.0 * avg + 4.0))); |
| } |
|
|
| |
| std::vector<int32_t> draft; |
| if (args.pld_enabled && (int)hist.size() >= args.pld_min_hist && is_master) { |
| draft = lookup_draft(args.pld_ngram, current_k); |
| |
| |
| |
| |
| if (!draft.empty()) { |
| const char* reason = draft_degenerate(draft); |
| if (reason) { |
| if (reason[0] == 'l') pld_rej_lowdist++; |
| else pld_rej_tailecho++; |
| draft.clear(); |
| } |
| } |
| } |
| |
| |
| bool has_draft = is_master ? !draft.empty() : false; |
| |
| if (args.tp_size > 1) { |
| extern HcclCtx* runner_hccl_ctx_shim(Runner&); |
| HcclCtx* ctx = runner_hccl_ctx_shim(runner); |
| DeviceBuffer flag(4); |
| int32_t f = has_draft ? 1 : 0; |
| ACL_CHECK(aclrtMemcpy(flag.get(), 4, &f, 4, ACL_MEMCPY_HOST_TO_DEVICE)); |
| hccl_broadcast(*ctx, flag.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream()); |
| ACL_CHECK(aclrtSynchronizeStream(runner.stream())); |
| ACL_CHECK(aclrtMemcpy(&f, 4, flag.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST)); |
| has_draft = (f != 0); |
| if (has_draft) { |
| std::vector<int32_t> d = draft; |
| broadcast_token_ids(runner, d, args.max_seq, is_master); |
| draft = d; |
| } else { |
| draft.clear(); |
| } |
| } |
|
|
| if (args.pld_enabled && (int)draft.size() >= 1 && args.temperature == 0.0f) { |
| |
| std::vector<int32_t> batch_input = { next_id }; |
| for (auto d : draft) batch_input.push_back(d); |
| int S = (int)batch_input.size(); |
| DeviceBuffer batch_logits; |
| if (!runner.decode_batch(batch_input.data(), S, batch_logits)) break; |
|
|
| std::vector<uint16_t> blh(S * V); |
| if (is_master) ACL_CHECK(aclrtMemcpy(blh.data(), S*V*2, batch_logits.get(), S*V*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
|
|
| |
| |
| |
| int accept = 0, new_next = next_id; |
| if (is_master) { |
| for (int i = 0; i < S - 1; i++) { |
| int pred = 0; float bv = bf16_to_float(blh[i * V]); |
| for (int k = 1; k < V; k++) { float v = bf16_to_float(blh[i*V + k]); if (v > bv) { bv = v; pred = k; } } |
| if (pred == (int)draft[i]) accept++; |
| else { new_next = pred; break; } |
| } |
| if (accept == S - 1) { |
| |
| int pred = 0; float bv = bf16_to_float(blh[(S-1) * V]); |
| for (int k = 1; k < V; k++) { float v = bf16_to_float(blh[(S-1)*V + k]); if (v > bv) { bv = v; pred = k; } } |
| new_next = pred; |
| } |
| } |
|
|
| |
| if (args.tp_size > 1) { |
| int32_t packed[2] = { (int32_t)accept, (int32_t)new_next }; |
| std::vector<int32_t> p(packed, packed + 2); |
| broadcast_token_ids(runner, p, args.max_seq, is_master); |
| if (p.size() == 2) { accept = p[0]; new_next = p[1]; } |
| } |
|
|
| |
| int64_t rewind = (int64_t)(S - 1 - accept); |
| if (rewind > 0) runner.rewind_cache(rewind); |
|
|
| |
| for (int i = 0; i < accept; i++) { |
| int tok = (int)draft[i]; |
| hist.push_back(tok); |
| generated.push_back(tok); |
| st.decoded++; |
| if (is_master && args.stream) { printf("%s", tokenizer.decode(tok).c_str()); fflush(stdout); } |
| if (is_eos(tok)) { st.hit_eos = true; break; } |
| } |
| pld_verifies++; pld_accepted += accept; |
| |
| if (args.pld_adaptive) { |
| recent_accepts.push_back(accept); |
| if ((int)recent_accepts.size() > ADAPT_WINDOW) recent_accepts.erase(recent_accepts.begin()); |
| } |
| if (st.hit_eos) break; |
|
|
| |
| hist.push_back(new_next); |
| generated.push_back(new_next); |
| st.decoded++; |
| if (is_master && args.stream) { printf("%s", tokenizer.decode(new_next).c_str()); fflush(stdout); } |
| if (is_eos(new_next)) { st.hit_eos = true; break; } |
| next_id = new_next; |
| } else { |
| |
| DeviceBuffer logits2; |
| if (!runner.decode((int32_t)next_id, logits2)) break; |
| load_logits(logits2); |
| next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng); |
| hist.push_back(next_id); |
| generated.push_back(next_id); |
| st.decoded++; |
| if (is_master && args.stream) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); } |
| if (is_eos(next_id)) { st.hit_eos = true; break; } |
| } |
| |
| |
| if (is_master && !loop_warned && args.pld_loop_warn > 0 && |
| (int)generated.size() >= args.pld_loop_warn) { |
| int tail = args.pld_loop_warn; |
| int anchor = generated[(int)generated.size() - tail]; |
| bool all_same = true; |
| for (int i = (int)generated.size() - tail + 1; i < (int)generated.size(); i++) { |
| if (generated[i] != anchor) { all_same = false; break; } |
| } |
| if (all_same) { |
| fprintf(stderr, "\n[warn] %d consecutive identical tokens — likely degeneration loop; output after this point is suspect\n", tail); |
| loop_warned = true; |
| } |
| } |
| } |
| auto t3 = std::chrono::steady_clock::now(); |
| st.decode_ms = std::chrono::duration<double, std::milli>(t3 - t2).count(); |
| if (is_master && args.pld_enabled) { |
| if (pld_verifies > 0) { |
| fprintf(stderr, "\n[pld] %d verifies, %d drafts accepted, avg=%.2f", |
| pld_verifies, pld_accepted, (double)pld_accepted / pld_verifies); |
| } else { |
| fprintf(stderr, "\n[pld] 0 verifies (all drafts blocked or none found)"); |
| } |
| if (args.pld_guard && (pld_rej_lowdist + pld_rej_tailecho) > 0) { |
| fprintf(stderr, "; guard rejections: low-distinct=%d tail-echo=%d", |
| pld_rej_lowdist, pld_rej_tailecho); |
| } |
| fprintf(stderr, "\n"); |
| } |
|
|
| if (is_master) { |
| if (args.stream) { printf("\n"); fflush(stdout); } |
| else { |
| std::string text = tokenizer.decode(generated); |
| printf("%s%s\n", args.chat_template ? "" : prompt.c_str(), text.c_str()); |
| } |
| } |
| return st; |
| } |
|
|
// Read the whole file at `path` (binary mode) into `out`, then strip one
// trailing line ending ("\n" or "\r\n") so a prompt saved by a text editor
// does not pass a stray newline / carriage return to the tokenizer.
// Returns false after printing to stderr if the file cannot be opened, its
// size cannot be determined (fseek/ftell failure, e.g. a pipe), or the read
// comes up short.
static bool load_file(const std::string& path, std::string& out) {
    FILE* f = fopen(path.c_str(), "rb");
    if (!f) { fprintf(stderr, "[cli] cannot open %s\n", path.c_str()); return false; }

    // ftell reports -1 on failure; resizing to that would request ~SIZE_MAX bytes.
    long sz = -1;
    if (fseek(f, 0, SEEK_END) == 0) sz = ftell(f);
    if (sz < 0 || fseek(f, 0, SEEK_SET) != 0) {
        fclose(f);
        fprintf(stderr, "[cli] cannot determine size of %s\n", path.c_str());
        return false;
    }

    out.resize((size_t)sz);
    const size_t n = fread(out.data(), 1, (size_t)sz, f);
    fclose(f);
    if (n != (size_t)sz) { fprintf(stderr, "[cli] short read from %s\n", path.c_str()); return false; }

    if (!out.empty() && out.back() == '\n') {
        out.pop_back();
        if (!out.empty() && out.back() == '\r') out.pop_back();
    }
    return true;
}
|
|
| int main(int argc, char** argv) { |
| Args args; |
| if (!parse_args(argc, argv, args)) return 1; |
|
|
| |
| if (!args.prompt_file.empty()) { |
| if (!load_file(args.prompt_file, args.prompt)) return 1; |
| } |
|
|
| const bool is_master = (args.tp_rank == 0); |
| std::mt19937 rng(args.seed ? args.seed : |
| (uint64_t)std::chrono::steady_clock::now().time_since_epoch().count()); |
|
|
| if (is_master) { |
| printf("[cli] model=%s\n", args.model_dir.c_str()); |
| printf("[cli] tp=%d n_predict=%d temp=%.2f top_k=%d top_p=%.2f chat=%d interactive=%d\n", |
| args.tp_size, args.n_predict, args.temperature, args.top_k, args.top_p, |
| args.chat_template, args.interactive); |
| fflush(stdout); |
| } |
|
|
| Tokenizer tokenizer; |
| if (!tokenizer.load(args.vocab_path)) { |
| fprintf(stderr, "[cli] failed to load vocab %s\n", args.vocab_path.c_str()); return 1; |
| } |
|
|
| Runner runner; |
| int num_layers = args.num_layers; |
| if (num_layers == 0) { |
| ModelConfig probe; |
| if (!probe.load_from_json(args.model_dir + "/config.json")) return 1; |
| num_layers = (int)probe.num_hidden_layers; |
| } |
| if (!runner.init(args.model_dir, args.tp_size, args.tp_rank, |
| num_layers, args.max_seq, args.device_id)) return 1; |
| if (const char* p = std::getenv("LCA_PROFILE"); p && std::atoi(p) != 0) { |
| runner.profile_enabled = true; |
| } |
| |
| if (const char* w = std::getenv("LCA_WARMUP"); w) { |
| int n = std::atoi(w); |
| if (n > 0) runner.warmup(n); |
| } |
|
|
| if (args.interactive) { |
| const bool multi_turn = args.chat_template && !args.reset_each_turn; |
| if (is_master) { |
| printf("\n[cli] === interactive mode ===\n"); |
| if (multi_turn) { |
| printf("[cli] multi-turn chat (KV cache preserved). Commands: 'quit', 'reset'.\n"); |
| if (!args.system_prompt.empty()) { |
| printf("[cli] system: %s\n", args.system_prompt.c_str()); |
| } |
| } else { |
| printf("[cli] stateless mode (KV cache reset each turn). Command: 'quit'.\n"); |
| if (!args.chat_template) { |
| printf("[cli] (hint: add --chat for multi-turn conversational memory)\n"); |
| } |
| } |
| fflush(stdout); |
| } |
|
|
| |
| std::vector<std::pair<std::string, std::string>> conversation; |
| if (multi_turn && !args.system_prompt.empty()) { |
| conversation.emplace_back("system", args.system_prompt); |
| } |
|
|
| auto* hccl_ctx = runner_hccl_ctx_shim(runner); |
|
|
| |
| auto broadcast_signal = [&](int32_t sig)->int32_t { |
| if (args.tp_size <= 1) return sig; |
| DeviceBuffer s(4); |
| ACL_CHECK(aclrtMemcpy(s.get(), 4, &sig, 4, ACL_MEMCPY_HOST_TO_DEVICE)); |
| hccl_broadcast(*hccl_ctx, s.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream()); |
| ACL_CHECK(aclrtSynchronizeStream(runner.stream())); |
| int32_t r; ACL_CHECK(aclrtMemcpy(&r, 4, s.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST)); |
| return r; |
| }; |
|
|
| while (true) { |
| std::string prompt; |
| int32_t sig = 0; |
| if (is_master) { |
| printf("\n> "); fflush(stdout); |
| if (!std::getline(std::cin, prompt)) sig = 1; |
| else if (prompt == "quit" || prompt == "exit") sig = 1; |
| else if (prompt == "reset") sig = 2; |
| else if (prompt.empty()) sig = 3; |
| } |
| sig = broadcast_signal(sig); |
| if (sig == 1) break; |
| if (sig == 2) { |
| runner.reset_cache(); |
| conversation.clear(); |
| if (multi_turn && !args.system_prompt.empty()) |
| conversation.emplace_back("system", args.system_prompt); |
| if (is_master) { printf("[cli] (cache + conversation reset)\n"); fflush(stdout); } |
| continue; |
| } |
| if (sig == 3) continue; |
|
|
| TurnStats st; |
| if (multi_turn) { |
| |
| if (is_master) conversation.emplace_back("user", prompt); |
| |
| |
| std::vector<int32_t> full_ids; |
| if (is_master) { |
| auto raw = tokenizer.encode_conversation_via_python(args.model_dir, conversation, true); |
| full_ids.reserve(raw.size()); |
| for (int v : raw) full_ids.push_back((int32_t)v); |
| } |
| |
| if (args.tp_size > 1) { |
| if (!broadcast_token_ids(runner, full_ids, args.max_seq, is_master)) break; |
| } |
| if (full_ids.empty()) { if (is_master) printf("[cli] tokenize failed\n"); continue; } |
|
|
| int64_t past = runner.past_len(); |
| if ((int64_t)full_ids.size() < past) { runner.reset_cache(); past = 0; } |
| std::vector<int32_t> delta(full_ids.begin() + past, full_ids.end()); |
| if (delta.empty()) { |
| if (is_master) printf("[cli] (no new tokens)\n"); |
| continue; |
| } |
| |
| |
| |
| if ((int64_t)(past + delta.size()) + args.n_predict > args.max_seq) { |
| if (is_master) { |
| fprintf(stderr, "[cli] context %ld + gen %d > max_seq %d — auto-resetting\n", |
| (long)(past + delta.size()), args.n_predict, args.max_seq); |
| } |
| runner.reset_cache(); |
| |
| if (is_master) { |
| std::vector<std::pair<std::string, std::string>> fresh; |
| for (auto& m : conversation) if (m.first == "system") fresh.push_back(m); |
| if (!conversation.empty() && conversation.back().first == "user") { |
| fresh.push_back(conversation.back()); |
| } |
| conversation = std::move(fresh); |
| auto raw = tokenizer.encode_conversation_via_python(args.model_dir, conversation, true); |
| full_ids.clear(); |
| for (int v : raw) full_ids.push_back((int32_t)v); |
| } |
| if (args.tp_size > 1) { |
| if (!broadcast_token_ids(runner, full_ids, args.max_seq, is_master)) break; |
| } |
| delta.assign(full_ids.begin(), full_ids.end()); |
| past = 0; |
| } |
|
|
| |
| st.n_prompt = (int)delta.size(); |
| auto t0 = std::chrono::steady_clock::now(); |
| DeviceBuffer logits; |
| if (!runner.prefill(delta.data(), (int64_t)delta.size(), logits)) break; |
| auto t1 = std::chrono::steady_clock::now(); |
| st.prefill_ms = std::chrono::duration<double, std::milli>(t1 - t0).count(); |
|
|
| const int64_t V = runner.cfg().vocab_size; |
| std::vector<uint16_t> logits_h(V); |
| auto load_logits = [&](DeviceBuffer& buf) { |
| ACL_CHECK(aclrtMemcpy(logits_h.data(), V*2, buf.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| }; |
| auto is_eos = [&](int id) { |
| for (int e : args.eos_ids) if (id == e) return true; |
| return false; |
| }; |
| load_logits(logits); |
| int next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng); |
|
|
| std::vector<int> assistant_ids = { next_id }; |
| if (is_master) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); } |
| st.hit_eos = is_eos(next_id); |
|
|
| auto t2 = std::chrono::steady_clock::now(); |
| for (int step = 1; step < args.n_predict && !st.hit_eos; step++) { |
| DeviceBuffer logits2; |
| if (!runner.decode((int32_t)next_id, logits2)) break; |
| load_logits(logits2); |
| next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng); |
| assistant_ids.push_back(next_id); |
| st.decoded++; |
| if (is_master) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); } |
| if (is_eos(next_id)) { st.hit_eos = true; break; } |
| } |
| auto t3 = std::chrono::steady_clock::now(); |
| st.decode_ms = std::chrono::duration<double, std::milli>(t3 - t2).count(); |
| if (is_master) { printf("\n"); fflush(stdout); } |
|
|
| |
| |
| if (is_master) { |
| std::vector<int> content_ids; |
| for (int id : assistant_ids) { if (is_eos(id)) break; content_ids.push_back(id); } |
| conversation.emplace_back("assistant", utf8_trim_incomplete(tokenizer.decode(content_ids))); |
| } |
| } else { |
| |
| runner.reset_cache(); |
| st = run_turn(runner, tokenizer, args, prompt, rng, is_master); |
| } |
|
|
| if (is_master) { |
| double tgs = (st.decode_ms > 0) ? (st.decoded * 1000.0 / st.decode_ms) : 0.0; |
| printf("[perf] prefill %d tok %.0fms decode %d tok %.0fms = %.2f t/s%s past_len=%ld\n", |
| st.n_prompt, st.prefill_ms, st.decoded, st.decode_ms, tgs, |
| st.hit_eos ? " (EOS)" : "", runner.past_len()); |
| fflush(stdout); |
| } |
| } |
| if (is_master) printf("[cli] bye\n"); |
| return 0; |
| } |
|
|
| |
| TurnStats st = run_turn(runner, tokenizer, args, args.prompt, rng, is_master); |
| if (is_master) runner.print_profile_summary(); |
| if (is_master) { |
| if (st.hit_eos) printf("[cli] (hit EOS)\n"); |
| printf("\n[perf] prefill: %.1fms for %d tokens = %.2f t/s\n", |
| st.prefill_ms, st.n_prompt, |
| (st.prefill_ms > 0) ? (st.n_prompt * 1000.0 / st.prefill_ms) : 0.0); |
| if (st.decoded > 0) { |
| printf("[perf] decode : %.1fms for %d tokens = %.2f t/s (TG)\n", |
| st.decode_ms, st.decoded, (st.decoded * 1000.0) / st.decode_ms); |
| } |
| } |
| return 0; |
| } |
|
|