// main_cli.cpp — qwen3-moe-aclnn entry point.
//
// Usage:
// qwen3-moe-aclnn --model-dir <path> --prompt "<text>" --n-predict <N>
// [--tp-size 1|16] [--vocab <path>] [--max-seq N] [--num-layers N]
// [--chat] [--temperature 0.7] [--top-k 20] [--top-p 0.8] [--seed N]
// [--no-stream]
//
// At TP>1 each rank is a separate process (env TP_RANK=<i>, TP_SIZE=<n>) launched by
// scripts/tp_launch.sh. Only rank 0 prints text output.
#include "runner.h"
#include "tokenizer.h"
// Escape hatch for HCCL broadcast from within CLI (defined in runner.cpp)
HcclCtx* runner_hccl_ctx_shim(Runner& r);
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <random>
#include <string>
#include <vector>
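// bf16 stores the high 16 bits of an IEEE-754 binary32, so widening is a 16-bit left shift.
// Example: 0x3F80 << 16 = 0x3F800000, which reinterprets as 1.0f.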
static float bf16_to_float(uint16_t x) {
uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f;
}
// Truncate a string to the last complete UTF-8 character boundary. If the last 1-3 bytes
// form an incomplete multi-byte sequence (e.g., assistant response cut mid-codepoint at
// n_predict limit), drop them so the JSON encoder downstream sees only valid UTF-8.
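// Example: "€" encodes as E2 82 AC; if the text ends at ...E2 82, the 3-byte sequence is
// incomplete and both trailing bytes are dropped.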
static std::string utf8_trim_incomplete(const std::string& s) {
if (s.empty()) return s;
size_t n = s.size();
// Walk back up to 4 bytes looking for the start of a UTF-8 sequence.
for (size_t back = 0; back < 4 && back < n; back++) {
size_t i = n - 1 - back;
unsigned char c = (unsigned char)s[i];
if ((c & 0x80) == 0) { return s; } // ASCII: already complete
if ((c & 0xC0) == 0x80) { continue; } // continuation byte: keep going
// Start byte: 110xxxxx (2-byte), 1110xxxx (3-byte), 11110xxx (4-byte)
size_t need = 0;
if ((c & 0xE0) == 0xC0) need = 2;
else if ((c & 0xF0) == 0xE0) need = 3;
else if ((c & 0xF8) == 0xF0) need = 4;
else return s.substr(0, i); // invalid start — drop
size_t have = back + 1;
return (have >= need) ? s : s.substr(0, i); // trim incomplete trailing sequence
}
// Only reachable on invalid UTF-8 (four or more trailing continuation bytes); return as-is.
return s;
}
struct Args {
std::string model_dir;
std::string prompt = "The capital of France is";
std::string vocab_path = "tokenizer_data/vocab.bin";
int n_predict = 100;
int tp_size = 1;
int tp_rank = 0;
int num_layers = 0; // 0 = auto
int max_seq = 512;
int device_id = 0;
bool chat_template = false;
bool stream = true;
bool interactive = false;
bool reset_each_turn = false; // if true, REPL clears KV cache between turns (stateless)
std::string system_prompt; // optional system role for chat mode
std::string prompt_file; // read prompt from file (avoids shell escaping)
bool pld_enabled = false; // prompt lookup decoding
int pld_k = 10; // bench_pld_k.sh: K=10 median 105 t/s (3/3 runs 100+), K=8 was 35
int pld_ngram = 1; // n-gram match size; 1 with multi-level fallback works best
bool pld_adaptive = false; // fixed K=10 is simpler and mean-optimal; enable with --pld-adaptive
int pld_min_hist = 20; // skip PLD until history >= this (avoid early-token false matches)
// PLD degeneration guard (on by default): prevents PLD from amplifying repetition loops.
bool pld_guard = true; // --pld-no-guard disables
int pld_guard_distinct = 3; // reject draft if distinct tokens < this (≥K/3 heuristic)
int pld_guard_tail = 6; // reject if draft[0] matches all last N hist tokens
int pld_loop_warn = 8; // warn once when N consecutive identical tokens emitted
float temperature = 0.0f; // 0 = greedy
int top_k = 0; // 0 = disabled
float top_p = 1.0f; // 1.0 = disabled
uint64_t seed = 0; // 0 = use time
// Qwen3 EOS tokens (from generation_config.json)
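// 151645 = <|im_end|>, 151643 = <|endoftext|> in the Qwen tokenizer.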
std::vector<int> eos_ids = {151645, 151643};
};
static bool parse_args(int argc, char** argv, Args& a) {
for (int i = 1; i < argc; i++) {
std::string s = argv[i];
auto next = [&](const char* f)->const char* {
if (i + 1 >= argc) { fprintf(stderr, "missing value for %s\n", f); return nullptr; }
return argv[++i];
};
if (s == "--model-dir") { auto v = next(s.c_str()); if (!v) return false; a.model_dir = v; }
else if (s == "--prompt") { auto v = next(s.c_str()); if (!v) return false; a.prompt = v; }
else if (s == "--vocab") { auto v = next(s.c_str()); if (!v) return false; a.vocab_path = v; }
else if (s == "--n-predict") { auto v = next(s.c_str()); if (!v) return false; a.n_predict = std::atoi(v); }
else if (s == "--tp-size") { auto v = next(s.c_str()); if (!v) return false; a.tp_size = std::atoi(v); }
else if (s == "--num-layers") { auto v = next(s.c_str()); if (!v) return false; a.num_layers = std::atoi(v); }
else if (s == "--max-seq") { auto v = next(s.c_str()); if (!v) return false; a.max_seq = std::atoi(v); }
else if (s == "--device") { auto v = next(s.c_str()); if (!v) return false; a.device_id = std::atoi(v); }
else if (s == "--temperature") { auto v = next(s.c_str()); if (!v) return false; a.temperature = (float)std::atof(v); }
else if (s == "--top-k") { auto v = next(s.c_str()); if (!v) return false; a.top_k = std::atoi(v); }
else if (s == "--top-p") { auto v = next(s.c_str()); if (!v) return false; a.top_p = (float)std::atof(v); }
else if (s == "--seed") { auto v = next(s.c_str()); if (!v) return false; a.seed = (uint64_t)std::atoll(v); }
else if (s == "--chat") { a.chat_template = true; }
else if (s == "--no-stream") { a.stream = false; }
else if (s == "--interactive" || s == "-i") { a.interactive = true; }
else if (s == "--reset") { a.reset_each_turn = true; }
else if (s == "--system") { auto v = next(s.c_str()); if (!v) return false; a.system_prompt = v; }
else if (s == "--prompt-file") { auto v = next(s.c_str()); if (!v) return false; a.prompt_file = v; }
else if (s == "--pld") { a.pld_enabled = true; }
else if (s == "--pld-k") { auto v = next(s.c_str()); if (!v) return false; a.pld_k = std::atoi(v); }
else if (s == "--pld-ngram") { auto v = next(s.c_str()); if (!v) return false; a.pld_ngram = std::atoi(v); }
else if (s == "--pld-adaptive"){ a.pld_adaptive = true; }
else if (s == "--pld-fixed-k") { a.pld_adaptive = false; } // opt out of adaptive
else if (s == "--pld-min-hist"){ auto v = next(s.c_str()); if (!v) return false; a.pld_min_hist = std::atoi(v); }
else if (s == "--pld-no-guard"){ a.pld_guard = false; }
else if (s == "--pld-guard-distinct"){ auto v = next(s.c_str()); if (!v) return false; a.pld_guard_distinct = std::atoi(v); }
else if (s == "--pld-guard-tail"){ auto v = next(s.c_str()); if (!v) return false; a.pld_guard_tail = std::atoi(v); }
else if (s == "--pld-loop-warn"){ auto v = next(s.c_str()); if (!v) return false; a.pld_loop_warn = std::atoi(v); }
else if (s == "--help" || s == "-h") {
printf("Usage: %s --model-dir <path> [options]\n", argv[0]);
printf(" --prompt \"text\" prompt text (default: \"%s\")\n", a.prompt.c_str());
printf(" --prompt-file FILE read prompt from file (overrides --prompt)\n");
printf(" --n-predict N max tokens to generate (default: %d)\n", a.n_predict);
printf(" --tp-size N tensor parallelism (default: 1; or TP_SIZE env)\n");
printf(" --num-layers N limit layers, testing only (default: all)\n");
printf(" --max-seq N KV cache + context cap (default: %d)\n", a.max_seq);
printf(" --chat apply Qwen3 chat template\n");
printf(" --system \"text\" system role for chat\n");
printf(" --temperature F 0 = greedy; typical 0.7\n");
printf(" --top-k N 0 = disabled\n");
printf(" --top-p F 1.0 = disabled; typical 0.8\n");
printf(" --seed N 0 = time-based (default)\n");
printf(" --no-stream batch-print final text\n");
printf(" -i, --interactive REPL (multi-turn memory when --chat)\n");
printf(" --reset force stateless REPL (reset each turn)\n");
printf(" --pld enable Prompt Lookup Decoding (greedy only)\n");
printf(" --pld-k N draft window size (default: 4)\n");
printf(" --pld-ngram N match n-gram size (default: 2; multi-level fallback)\n");
printf(" --pld-adaptive adjust K based on recent accept rate\n");
printf(" --pld-min-hist N skip PLD until history >= N tokens (default: 20)\n");
printf(" --pld-no-guard disable degeneration guard (dangerous: can amplify loops)\n");
printf(" --pld-guard-distinct N reject draft with distinct tokens < N (default: 3)\n");
printf(" --pld-guard-tail N reject draft if draft[0] matches all last N hist (default: 6)\n");
printf(" --pld-loop-warn N warn once on N consecutive identical emitted tokens (default: 8)\n");
return false;
}
else { fprintf(stderr, "unknown arg: %s\n", s.c_str()); return false; }
}
if (a.model_dir.empty()) { fprintf(stderr, "--model-dir required\n"); return false; }
if (const char* r = std::getenv("TP_RANK")) a.tp_rank = std::atoi(r);
if (const char* s = std::getenv("TP_SIZE")) a.tp_size = std::atoi(s);
return true;
}
// Sample next token from logits. temperature=0 → greedy argmax. Otherwise top-k / top-p.
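// Worked example (illustrative numbers): logits {2.0, 1.0, 0.5}, temperature=0.7,
// top_k=2, top_p=0.9. Scaled: {2.86, 1.43, 0.71}; top-k keeps {2.86, 1.43};
// softmax gives {0.81, 0.19}; cumulative 0.81 < 0.9, so both survive top-p.
// Finally draw r in [0,1) and pick the first token whose cumulative probability reaches r.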
static int sample_token(const std::vector<uint16_t>& logits_bf16, int64_t V,
float temperature, int top_k, float top_p, std::mt19937& rng) {
if (temperature <= 0.0f) {
int best = 0;
float bv = bf16_to_float(logits_bf16[0]);
for (int64_t i = 1; i < V; i++) {
float v = bf16_to_float(logits_bf16[i]);
if (v > bv) { bv = v; best = (int)i; }
}
return best;
}
// Build (logit, id) list as float
std::vector<std::pair<float, int>> scored;
scored.reserve(V);
for (int64_t i = 0; i < V; i++) {
scored.emplace_back(bf16_to_float(logits_bf16[i]) / temperature, (int)i);
}
// Top-k: keep highest k entries (partial sort)
if (top_k > 0 && top_k < (int)scored.size()) {
std::nth_element(scored.begin(), scored.begin() + top_k, scored.end(),
[](const auto& a, const auto& b){ return a.first > b.first; });
scored.resize(top_k);
}
// Sort descending for top-p
std::sort(scored.begin(), scored.end(),
[](const auto& a, const auto& b){ return a.first > b.first; });
// Softmax (numerically stable)
float maxv = scored[0].first;
double sum = 0;
for (auto& p : scored) { p.first = std::exp(p.first - maxv); sum += p.first; }
for (auto& p : scored) p.first /= (float)sum;
// Top-p nucleus
if (top_p > 0.0f && top_p < 1.0f) {
double cum = 0;
size_t cutoff = scored.size();
for (size_t i = 0; i < scored.size(); i++) {
cum += scored[i].first;
if (cum >= top_p) { cutoff = i + 1; break; }
}
scored.resize(cutoff);
// re-normalize
double s = 0; for (auto& p : scored) s += p.first;
for (auto& p : scored) p.first /= (float)s;
}
// Sample
std::uniform_real_distribution<float> U(0.0f, 1.0f);
float r = U(rng), acc = 0.0f;
for (auto& p : scored) {
acc += p.first;
if (r <= acc) return p.second;
}
return scored.back().second;
}
// Broadcast a prompt's token_ids from rank 0 to all ranks. For TP>1 the non-master ranks need
// the tokens before prefill. We use HCCL broadcast: rank 0 provides the count, then the ids.
// Uses a pre-allocated device buffer (must be large enough for max_seq tokens).
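// Wire layout per prompt broadcast (root = rank 0), two HCCL collectives:
//   msg 1: [count]                  1 x int32
//   msg 2: [id_0 ... id_{count-1}]  count x int32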
static bool broadcast_token_ids(Runner& runner, std::vector<int32_t>& ids,
int64_t max_seq, bool is_master) {
const ModelConfig& cfg = runner.cfg();
if (cfg.tp_size <= 1) return true;
// Step 1: broadcast count (as int32 on device)
DeviceBuffer cnt_dev(4);
int32_t cnt = is_master ? (int32_t)ids.size() : 0;
ACL_CHECK(aclrtMemcpy(cnt_dev.get(), 4, &cnt, 4, ACL_MEMCPY_HOST_TO_DEVICE));
// Runner keeps its HCCL context private; runner_hccl_ctx_shim() (declared near the top of
// this file, defined in runner.cpp) exposes it so the CLI can call hccl_broadcast directly.
HcclCtx* ctx = runner_hccl_ctx_shim(runner);
if (!ctx) return false;
if (!hccl_broadcast(*ctx, cnt_dev.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream())) return false;
ACL_CHECK(aclrtSynchronizeStream(runner.stream()));
ACL_CHECK(aclrtMemcpy(&cnt, 4, cnt_dev.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST));
if (cnt <= 0 || cnt > (int32_t)max_seq) {
fprintf(stderr, "[rank %d] broadcast: bad count %d\n", cfg.tp_rank, cnt);
return false;
}
// Step 2: broadcast the id buffer
DeviceBuffer ids_dev(cnt * 4);
if (is_master) {
ACL_CHECK(aclrtMemcpy(ids_dev.get(), cnt*4, ids.data(), cnt*4, ACL_MEMCPY_HOST_TO_DEVICE));
}
if (!hccl_broadcast(*ctx, ids_dev.get(), cnt, HCCL_DATA_TYPE_INT32, 0, runner.stream())) return false;
ACL_CHECK(aclrtSynchronizeStream(runner.stream()));
if (!is_master) {
ids.resize(cnt);
ACL_CHECK(aclrtMemcpy(ids.data(), cnt*4, ids_dev.get(), cnt*4, ACL_MEMCPY_DEVICE_TO_HOST));
}
return true;
}
// Run one generation turn. Assumes KV cache is reset. Returns perf summary.
struct TurnStats {
double prefill_ms = 0; double decode_ms = 0;
int n_prompt = 0; int decoded = 0; bool hit_eos = false;
};
static TurnStats run_turn(Runner& runner, Tokenizer& tokenizer, const Args& args,
const std::string& prompt, std::mt19937& rng, bool is_master) {
TurnStats st;
// --- Tokenize (on master; broadcast for TP>1) ---
std::vector<int32_t> input_ids;
if (is_master) {
auto raw = tokenizer.encode_via_python(args.model_dir, prompt, args.chat_template);
if (raw.empty()) return st;
input_ids.reserve(raw.size());
for (int v : raw) input_ids.push_back((int32_t)v);
}
if (args.tp_size > 1) {
if (!broadcast_token_ids(runner, input_ids, args.max_seq, is_master)) return st;
}
if (input_ids.empty()) return st;
const int64_t V = runner.cfg().vocab_size;
std::vector<uint16_t> logits_h(V);
auto load_logits = [&](DeviceBuffer& buf) {
ACL_CHECK(aclrtMemcpy(logits_h.data(), V*2, buf.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));
};
auto is_eos = [&](int id) {
for (int e : args.eos_ids) if (id == e) return true;
return false;
};
// --- Prefill ---
st.n_prompt = (int)input_ids.size();
auto t0 = std::chrono::steady_clock::now();
DeviceBuffer logits;
if (!runner.prefill(input_ids.data(), (int64_t)input_ids.size(), logits)) return st;
auto t1 = std::chrono::steady_clock::now();
st.prefill_ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
load_logits(logits);
int next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng);
if (is_master && args.stream) {
if (!args.chat_template) printf("%s", prompt.c_str());
printf("%s", tokenizer.decode(next_id).c_str());
fflush(stdout);
}
std::vector<int> generated = { next_id };
st.hit_eos = is_eos(next_id);
// All tokens (prompt + generated) for PLD n-gram lookup. Non-master ranks still need to
// track consistent length for HCCL broadcast of draft proposals.
std::vector<int32_t> hist;
hist.reserve(input_ids.size() + args.n_predict + 16);
for (auto x : input_ids) hist.push_back(x);
hist.push_back(next_id);
// PLD n-gram lookup: search for suffix match ending at end-of-hist; return K tokens following match.
// Longer matches = more reliable drafts. Multi-level: try n, fall back to smaller n if no match.
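// Example: hist = [A B C A B], ngram = 2, K = 2. The suffix [A B] matches at start = 0,
// so after = 2 and the draft is [C, A], the tokens that followed the earlier [A B].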
auto lookup_one = [&](int ngram, int K) -> std::vector<int32_t> {
int hs = (int)hist.size();
if (hs < ngram + 1 || K <= 0) return {};
for (int start = hs - ngram - 1; start >= 0; start--) {
bool match = true;
for (int k = 0; k < ngram; k++) {
if (hist[start + k] != hist[hs - ngram + k]) { match = false; break; }
}
if (match) {
int after = start + ngram;
std::vector<int32_t> d;
for (int k = 0; k < K && after + k < hs; k++) {
d.push_back(hist[after + k]);
if (is_eos(hist[after + k])) break;
}
if (!d.empty()) return d;
}
}
return {};
};
// Multi-level: try configured n first, then n-1, then n-2 (down to 1).
auto lookup_draft = [&](int ngram, int K) -> std::vector<int32_t> {
for (int n = ngram; n >= 1; n--) {
auto d = lookup_one(n, K);
if (!d.empty()) return d;
}
return {};
};
// Degeneration guard: classify a draft as repetition-induced so we can fall back to single
// decode (and avoid PLD amplifying model's own repetition loop into a runaway "W W W …" mess).
// Returns nullptr if draft is OK, else a short reason string for stats.
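// Example: draft [W W W W] has 1 distinct token (< pld_guard_distinct=3) -> "low-distinct";
// hist tail [... W W W W W W] with draft[0] == W -> "tail-echo".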
auto draft_degenerate = [&](const std::vector<int32_t>& d) -> const char* {
if (!args.pld_guard || d.empty()) return nullptr;
// (1) distinct-token count: a draft of K tokens with < args.pld_guard_distinct distinct
// values means n-gram is echoing a loop. Only apply when draft is long enough.
if ((int)d.size() >= 3) {
int distinct = 0;
for (int i = 0; i < (int)d.size(); i++) {
bool seen = false;
for (int j = 0; j < i; j++) { if (d[j] == d[i]) { seen = true; break; } }
if (!seen) distinct++;
}
if (distinct < args.pld_guard_distinct) return "low-distinct";
}
// (2) tail echo: if the last N hist tokens are all equal to draft[0], the model is already
// in a short loop — accepting the draft will just confirm the loop at batch speed.
int tail_n = std::min(args.pld_guard_tail, (int)hist.size());
if (tail_n >= 3) {
int matches = 0;
for (int i = (int)hist.size() - tail_n; i < (int)hist.size(); i++) {
if (hist[i] == d[0]) matches++;
}
if (matches == tail_n) return "tail-echo";
}
return nullptr;
};
// --- Decode loop ---
auto t2 = std::chrono::steady_clock::now();
int pld_verifies = 0, pld_accepted = 0;
int pld_rej_lowdist = 0, pld_rej_tailecho = 0; // guard rejection counters
bool loop_warned = false; // warn-once state
// Adaptive K state: recent accept counts for moving-average decisions
const int ADAPT_WINDOW = 8;
std::vector<int> recent_accepts;
int current_k = args.pld_k;
while (st.decoded < args.n_predict - 1 && !st.hit_eos) {
// Adaptive K: scale K with recent accept rate.
// No auto-disable: since S=K+1 forward ≈ S=1 forward (latency-bound), even accept=0.1
// still nets slightly positive — PLD doesn't "hurt" as long as ngram lookup is cheap.
if (args.pld_adaptive && (int)recent_accepts.size() >= ADAPT_WINDOW) {
double avg = 0;
for (int a : recent_accepts) avg += a;
avg /= recent_accepts.size();
// Aim: K = 2*avg + 4 (generous window to catch upswings). Clamp [4, 12].
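// e.g. avg accept 2.5 -> K = round(2*2.5 + 4) = 9; an all-miss window (avg 0) floors at K = 4.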
current_k = std::max(4, std::min(12, (int)std::round(2.0 * avg + 4.0)));
}
// Try PLD speculation path — skip until enough history accumulated
std::vector<int32_t> draft;
if (args.pld_enabled && (int)hist.size() >= args.pld_min_hist && is_master) {
draft = lookup_draft(args.pld_ngram, current_k);
// Degeneration guard: if draft looks like repetition-loop echo, drop it so this
// iteration falls through to normal single decode. This does NOT stop a loop the model
// is already in (greedy is deterministic), but it prevents PLD from running the loop
// at batch speed while masquerading as a speedup.
if (!draft.empty()) {
const char* reason = draft_degenerate(draft);
if (reason) {
if (reason[0] == 'l') pld_rej_lowdist++;
else pld_rej_tailecho++;
draft.clear();
}
}
}
// For TP>1, broadcast draft across ranks. Only broadcast if master has a non-empty draft;
// otherwise all ranks take the no-draft path (normal decode).
bool has_draft = is_master ? !draft.empty() : false;
// Broadcast the "has_draft" flag (using a 1-element count: 1 = yes, 0 = no)
if (args.tp_size > 1) {
HcclCtx* ctx = runner_hccl_ctx_shim(runner);
DeviceBuffer flag(4);
int32_t f = has_draft ? 1 : 0;
ACL_CHECK(aclrtMemcpy(flag.get(), 4, &f, 4, ACL_MEMCPY_HOST_TO_DEVICE));
hccl_broadcast(*ctx, flag.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream());
ACL_CHECK(aclrtSynchronizeStream(runner.stream()));
ACL_CHECK(aclrtMemcpy(&f, 4, flag.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST));
has_draft = (f != 0);
if (has_draft) {
std::vector<int32_t> d = draft;
broadcast_token_ids(runner, d, args.max_seq, is_master);
draft = d;
} else {
draft.clear();
}
}
if (args.pld_enabled && (int)draft.size() >= 1 && args.temperature == 0.0f) {
// Batch verify: input = [next_id, draft[0], ..., draft[K-1]]
std::vector<int32_t> batch_input = { next_id };
for (auto d : draft) batch_input.push_back(d);
int S = (int)batch_input.size();
DeviceBuffer batch_logits;
if (!runner.decode_batch(batch_input.data(), S, batch_logits)) break;
std::vector<uint16_t> blh(S * V);
if (is_master) ACL_CHECK(aclrtMemcpy(blh.data(), S*V*2, batch_logits.get(), S*V*2, ACL_MEMCPY_DEVICE_TO_HOST));
// Accept longest prefix: draft[i] is "candidate" for position past+i+1.
// blh row i predicts position past+i+1 (follows batch_input[i]).
// Verify: blh[0].argmax == draft[0]? (i.e., does model agree with draft's first proposal)
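// Example (K = 2): next_id = t0, draft = [d1, d2] -> batch_input = [t0, d1, d2], S = 3.
// Row 0 is the argmax after t0: if it equals d1, accept and check row 1 (argmax after d1)
// against d2. If all rows match, row 2's argmax is a free "bonus" token. On the first
// mismatch, that row's argmax becomes new_next and the rest of the draft is discarded.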
int accept = 0, new_next = next_id;
if (is_master) {
for (int i = 0; i < S - 1; i++) {
int pred = 0; float bv = bf16_to_float(blh[i * V]);
for (int k = 1; k < V; k++) { float v = bf16_to_float(blh[i*V + k]); if (v > bv) { bv = v; pred = k; } }
if (pred == (int)draft[i]) accept++;
else { new_next = pred; break; }
}
if (accept == S - 1) {
// All draft accepted, bonus from last row
int pred = 0; float bv = bf16_to_float(blh[(S-1) * V]);
for (int k = 1; k < V; k++) { float v = bf16_to_float(blh[(S-1)*V + k]); if (v > bv) { bv = v; pred = k; } }
new_next = pred;
}
}
// Broadcast accept count + new_next across TP ranks
if (args.tp_size > 1) {
int32_t packed[2] = { (int32_t)accept, (int32_t)new_next };
std::vector<int32_t> p(packed, packed + 2);
broadcast_token_ids(runner, p, args.max_seq, is_master);
if (p.size() == 2) { accept = p[0]; new_next = p[1]; }
}
// Rewind KV for rejected drafts
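// decode_batch advanced the cache by S positions (next_id + K drafts); keep next_id's and
// the accepted drafts' entries, drop the K - accept rejected ones. The bonus token has not
// been fed through the model yet, so it needs no rewind.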
int64_t rewind = (int64_t)(S - 1 - accept); // drafts not accepted (excluding bonus)
if (rewind > 0) runner.rewind_cache(rewind);
// Commit accepted drafts + bonus to hist and emit
for (int i = 0; i < accept; i++) {
int tok = (int)draft[i];
hist.push_back(tok);
generated.push_back(tok);
st.decoded++;
if (is_master && args.stream) { printf("%s", tokenizer.decode(tok).c_str()); fflush(stdout); }
if (is_eos(tok)) { st.hit_eos = true; break; }
}
pld_verifies++; pld_accepted += accept;
// Track recent accept for adaptive K
if (args.pld_adaptive) {
recent_accepts.push_back(accept);
if ((int)recent_accepts.size() > ADAPT_WINDOW) recent_accepts.erase(recent_accepts.begin());
}
if (st.hit_eos) break;
// Bonus token (new_next) is also committed
hist.push_back(new_next);
generated.push_back(new_next);
st.decoded++;
if (is_master && args.stream) { printf("%s", tokenizer.decode(new_next).c_str()); fflush(stdout); }
if (is_eos(new_next)) { st.hit_eos = true; break; }
next_id = new_next;
} else {
// Normal decode
DeviceBuffer logits2;
if (!runner.decode((int32_t)next_id, logits2)) break;
load_logits(logits2);
next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng);
hist.push_back(next_id);
generated.push_back(next_id);
st.decoded++;
if (is_master && args.stream) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); }
if (is_eos(next_id)) { st.hit_eos = true; break; }
}
// Loop-warn: emit a one-shot warning to stderr if the tail of generated is all-same-token.
// Does not stop generation (user may want to see what happens) — just flags output quality.
if (is_master && !loop_warned && args.pld_loop_warn > 0 &&
(int)generated.size() >= args.pld_loop_warn) {
int tail = args.pld_loop_warn;
int anchor = generated[(int)generated.size() - tail];
bool all_same = true;
for (int i = (int)generated.size() - tail + 1; i < (int)generated.size(); i++) {
if (generated[i] != anchor) { all_same = false; break; }
}
if (all_same) {
fprintf(stderr, "\n[warn] %d consecutive identical tokens — likely degeneration loop; output after this point is suspect\n", tail);
loop_warned = true;
}
}
}
auto t3 = std::chrono::steady_clock::now();
st.decode_ms = std::chrono::duration<double, std::milli>(t3 - t2).count();
if (is_master && args.pld_enabled) {
if (pld_verifies > 0) {
fprintf(stderr, "\n[pld] %d verifies, %d drafts accepted, avg=%.2f",
pld_verifies, pld_accepted, (double)pld_accepted / pld_verifies);
} else {
fprintf(stderr, "\n[pld] 0 verifies (all drafts blocked or none found)");
}
if (args.pld_guard && (pld_rej_lowdist + pld_rej_tailecho) > 0) {
fprintf(stderr, "; guard rejections: low-distinct=%d tail-echo=%d",
pld_rej_lowdist, pld_rej_tailecho);
}
fprintf(stderr, "\n");
}
if (is_master) {
if (args.stream) { printf("\n"); fflush(stdout); }
else {
std::string text = tokenizer.decode(generated);
printf("%s%s\n", args.chat_template ? "" : prompt.c_str(), text.c_str());
}
}
return st;
}
static bool load_file(const std::string& path, std::string& out) {
FILE* f = fopen(path.c_str(), "rb");
if (!f) { fprintf(stderr, "[cli] cannot open %s\n", path.c_str()); return false; }
fseek(f, 0, SEEK_END); long sz = ftell(f); fseek(f, 0, SEEK_SET);
if (sz < 0) { fprintf(stderr, "[cli] cannot determine size of %s\n", path.c_str()); fclose(f); return false; }
out.resize((size_t)sz);
size_t n = fread(out.data(), 1, sz, f);
fclose(f);
if ((long)n != sz) { fprintf(stderr, "[cli] short read from %s\n", path.c_str()); return false; }
// Strip a single trailing newline (common in text files)
if (!out.empty() && out.back() == '\n') out.pop_back();
return true;
}
int main(int argc, char** argv) {
Args args;
if (!parse_args(argc, argv, args)) return 1;
// --prompt-file overrides --prompt
if (!args.prompt_file.empty()) {
if (!load_file(args.prompt_file, args.prompt)) return 1;
}
const bool is_master = (args.tp_rank == 0);
std::mt19937 rng(args.seed ? args.seed :
(uint64_t)std::chrono::steady_clock::now().time_since_epoch().count());
if (is_master) {
printf("[cli] model=%s\n", args.model_dir.c_str());
printf("[cli] tp=%d n_predict=%d temp=%.2f top_k=%d top_p=%.2f chat=%d interactive=%d\n",
args.tp_size, args.n_predict, args.temperature, args.top_k, args.top_p,
args.chat_template, args.interactive);
fflush(stdout);
}
Tokenizer tokenizer;
if (!tokenizer.load(args.vocab_path)) {
fprintf(stderr, "[cli] failed to load vocab %s\n", args.vocab_path.c_str()); return 1;
}
Runner runner;
int num_layers = args.num_layers;
if (num_layers == 0) {
ModelConfig probe;
if (!probe.load_from_json(args.model_dir + "/config.json")) return 1;
num_layers = (int)probe.num_hidden_layers;
}
if (!runner.init(args.model_dir, args.tp_size, args.tp_rank,
num_layers, args.max_seq, args.device_id)) return 1;
if (const char* p = std::getenv("LCA_PROFILE"); p && std::atoi(p) != 0) {
runner.profile_enabled = true;
}
// Warmup: cuts cold-start latency. Controlled via the LCA_WARMUP env var (default 0 = disabled,
// preserving prior behavior).
if (const char* w = std::getenv("LCA_WARMUP"); w) {
int n = std::atoi(w);
if (n > 0) runner.warmup(n);
}
if (args.interactive) {
const bool multi_turn = args.chat_template && !args.reset_each_turn;
if (is_master) {
printf("\n[cli] === interactive mode ===\n");
if (multi_turn) {
printf("[cli] multi-turn chat (KV cache preserved). Commands: 'quit', 'reset'.\n");
if (!args.system_prompt.empty()) {
printf("[cli] system: %s\n", args.system_prompt.c_str());
}
} else {
printf("[cli] stateless mode (KV cache reset each turn). Command: 'quit'.\n");
if (!args.chat_template) {
printf("[cli] (hint: add --chat for multi-turn conversational memory)\n");
}
}
fflush(stdout);
}
// Conversation history: accumulated (role, content) pairs. System prompt seeded if present.
std::vector<std::pair<std::string, std::string>> conversation;
if (multi_turn && !args.system_prompt.empty()) {
conversation.emplace_back("system", args.system_prompt);
}
auto* hccl_ctx = runner_hccl_ctx_shim(runner);
// Signal types (broadcast as int32): 0 = normal turn, 1 = quit, 2 = reset, 3 = empty input (skip).
auto broadcast_signal = [&](int32_t sig)->int32_t {
if (args.tp_size <= 1) return sig;
DeviceBuffer s(4);
ACL_CHECK(aclrtMemcpy(s.get(), 4, &sig, 4, ACL_MEMCPY_HOST_TO_DEVICE));
hccl_broadcast(*hccl_ctx, s.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream());
ACL_CHECK(aclrtSynchronizeStream(runner.stream()));
int32_t r; ACL_CHECK(aclrtMemcpy(&r, 4, s.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST));
return r;
};
while (true) {
std::string prompt;
int32_t sig = 0;
if (is_master) {
printf("\n> "); fflush(stdout);
if (!std::getline(std::cin, prompt)) sig = 1;
else if (prompt == "quit" || prompt == "exit") sig = 1;
else if (prompt == "reset") sig = 2;
else if (prompt.empty()) sig = 3; // skip
}
sig = broadcast_signal(sig);
if (sig == 1) break;
if (sig == 2) {
runner.reset_cache();
conversation.clear();
if (multi_turn && !args.system_prompt.empty())
conversation.emplace_back("system", args.system_prompt);
if (is_master) { printf("[cli] (cache + conversation reset)\n"); fflush(stdout); }
continue;
}
if (sig == 3) continue;
TurnStats st;
if (multi_turn) {
// Append user message and tokenize the full conversation; prefill only the DELTA.
if (is_master) conversation.emplace_back("user", prompt);
// Non-master ranks don't track the conversation: rank 0 tokenizes and the full token
// ids are broadcast below, so every rank derives the same delta.
std::vector<int32_t> full_ids;
if (is_master) {
auto raw = tokenizer.encode_conversation_via_python(args.model_dir, conversation, /*gen_prompt=*/true);
full_ids.reserve(raw.size());
for (int v : raw) full_ids.push_back((int32_t)v);
}
// Broadcast full_ids (variable-length) via the count-then-ids protocol in broadcast_token_ids.
if (args.tp_size > 1) {
if (!broadcast_token_ids(runner, full_ids, args.max_seq, is_master)) break;
}
if (full_ids.empty()) { if (is_master) printf("[cli] tokenize failed\n"); continue; }
int64_t past = runner.past_len();
if ((int64_t)full_ids.size() < past) { runner.reset_cache(); past = 0; }
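// Incremental prefill assumes the first `past` tokens of full_ids match the cached prefix
// (an append-only template). A template that rewrites earlier turns is only partially caught
// by the shrink check above; a same-length rewrite would silently desync the KV cache.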
std::vector<int32_t> delta(full_ids.begin() + past, full_ids.end());
if (delta.empty()) {
if (is_master) printf("[cli] (no new tokens)\n");
continue;
}
// Overflow check — simple policy: warn + auto-reset if the turn + generation
// would exceed max_seq. Conversation history is cleared (except --system) so
// the user's current prompt still fits.
if ((int64_t)(past + delta.size()) + args.n_predict > args.max_seq) {
if (is_master) {
fprintf(stderr, "[cli] context %ld + gen %d > max_seq %d — auto-resetting\n",
(long)(past + delta.size()), args.n_predict, args.max_seq);
}
runner.reset_cache();
// Rebuild conversation: keep only system + current user turn.
if (is_master) {
std::vector<std::pair<std::string, std::string>> fresh;
for (auto& m : conversation) if (m.first == "system") fresh.push_back(m);
if (!conversation.empty() && conversation.back().first == "user") {
fresh.push_back(conversation.back());
}
conversation = std::move(fresh);
auto raw = tokenizer.encode_conversation_via_python(args.model_dir, conversation, true);
full_ids.clear();
for (int v : raw) full_ids.push_back((int32_t)v);
}
if (args.tp_size > 1) {
if (!broadcast_token_ids(runner, full_ids, args.max_seq, is_master)) break;
}
delta.assign(full_ids.begin(), full_ids.end());
past = 0;
}
// --- Prefill the delta ---
st.n_prompt = (int)delta.size();
auto t0 = std::chrono::steady_clock::now();
DeviceBuffer logits;
if (!runner.prefill(delta.data(), (int64_t)delta.size(), logits)) break;
auto t1 = std::chrono::steady_clock::now();
st.prefill_ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
const int64_t V = runner.cfg().vocab_size;
std::vector<uint16_t> logits_h(V);
auto load_logits = [&](DeviceBuffer& buf) {
ACL_CHECK(aclrtMemcpy(logits_h.data(), V*2, buf.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST));
};
auto is_eos = [&](int id) {
for (int e : args.eos_ids) if (id == e) return true;
return false;
};
load_logits(logits);
int next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng);
std::vector<int> assistant_ids = { next_id };
if (is_master) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); }
st.hit_eos = is_eos(next_id);
auto t2 = std::chrono::steady_clock::now();
for (int step = 1; step < args.n_predict && !st.hit_eos; step++) {
DeviceBuffer logits2;
if (!runner.decode((int32_t)next_id, logits2)) break;
load_logits(logits2);
next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng);
assistant_ids.push_back(next_id);
st.decoded++;
if (is_master) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); }
if (is_eos(next_id)) { st.hit_eos = true; break; }
}
auto t3 = std::chrono::steady_clock::now();
st.decode_ms = std::chrono::duration<double, std::milli>(t3 - t2).count();
if (is_master) { printf("\n"); fflush(stdout); }
// Record assistant reply in conversation (strip trailing EOS before decode,
// and trim incomplete UTF-8 tail if generation was cut mid-codepoint).
if (is_master) {
std::vector<int> content_ids;
for (int id : assistant_ids) { if (is_eos(id)) break; content_ids.push_back(id); }
conversation.emplace_back("assistant", utf8_trim_incomplete(tokenizer.decode(content_ids)));
}
} else {
// Stateless: reset cache, one-shot prompt
runner.reset_cache();
st = run_turn(runner, tokenizer, args, prompt, rng, is_master);
}
if (is_master) {
double tgs = (st.decode_ms > 0) ? (st.decoded * 1000.0 / st.decode_ms) : 0.0;
printf("[perf] prefill %d tok %.0fms decode %d tok %.0fms = %.2f t/s%s past_len=%ld\n",
st.n_prompt, st.prefill_ms, st.decoded, st.decode_ms, tgs,
st.hit_eos ? " (EOS)" : "", runner.past_len());
fflush(stdout);
}
}
if (is_master) printf("[cli] bye\n");
return 0;
}
// One-shot mode
TurnStats st = run_turn(runner, tokenizer, args, args.prompt, rng, is_master);
if (is_master) runner.print_profile_summary();
if (is_master) {
if (st.hit_eos) printf("[cli] (hit EOS)\n");
printf("\n[perf] prefill: %.1fms for %d tokens = %.2f t/s\n",
st.prefill_ms, st.n_prompt,
(st.prefill_ms > 0) ? (st.n_prompt * 1000.0 / st.prefill_ms) : 0.0);
if (st.decoded > 0) {
printf("[perf] decode : %.1fms for %d tokens = %.2f t/s (TG)\n",
st.decode_ms, st.decoded, (st.decoded * 1000.0) / st.decode_ms);
}
}
return 0;
}