// main_cli.cpp — qwen3-moe-aclnn entry point. // // Usage: // qwen3-moe-aclnn --model-dir --prompt "" --n-predict // [--tp-size 1|16] [--vocab ] [--max-seq N] [--num-layers N] // [--chat] [--temperature 0.7] [--top-k 20] [--top-p 0.8] [--seed N] // [--no-stream] // // At TP>1 each rank is a separate process (env TP_RANK=, TP_SIZE=) launched by // scripts/tp_launch.sh. Only rank 0 prints text output. #include "runner.h" #include "tokenizer.h" // Escape hatch for HCCL broadcast from within CLI (defined in runner.cpp) HcclCtx* runner_hccl_ctx_shim(Runner& r); #include #include #include #include #include #include #include #include #include #include static float bf16_to_float(uint16_t x) { uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f; } // Truncate a string to the last complete UTF-8 character boundary. If the last 1-3 bytes // form an incomplete multi-byte sequence (e.g., assistant response cut mid-codepoint at // n_predict limit), drop them so the JSON encoder downstream sees only valid UTF-8. static std::string utf8_trim_incomplete(const std::string& s) { if (s.empty()) return s; size_t n = s.size(); // Walk back up to 4 bytes looking for the start of a UTF-8 sequence. for (size_t back = 0; back < 4 && back < n; back++) { size_t i = n - 1 - back; unsigned char c = (unsigned char)s[i]; if ((c & 0x80) == 0) { return s; } // ASCII: already complete if ((c & 0xC0) == 0x80) { continue; } // continuation byte: keep going // Start byte: 110xxxxx (2-byte), 1110xxxx (3-byte), 11110xxx (4-byte) size_t need = 0; if ((c & 0xE0) == 0xC0) need = 2; else if ((c & 0xF0) == 0xE0) need = 3; else if ((c & 0xF8) == 0xF0) need = 4; else return s.substr(0, i); // invalid start — drop size_t have = back + 1; return (have >= need) ? s : s.substr(0, i); // trim incomplete trailing sequence } // Should not reach here; return as-is. 
return s; } struct Args { std::string model_dir; std::string prompt = "The capital of France is"; std::string vocab_path = "tokenizer_data/vocab.bin"; int n_predict = 100; int tp_size = 1; int tp_rank = 0; int num_layers = 0; // 0 = auto int max_seq = 512; int device_id = 0; bool chat_template = false; bool stream = true; bool interactive = false; bool reset_each_turn = false; // if true, REPL clears KV cache between turns (stateless) std::string system_prompt; // optional system role for chat mode std::string prompt_file; // read prompt from file (avoids shell escaping) bool pld_enabled = false; // prompt lookup decoding int pld_k = 10; // bench_pld_k.sh: K=10 median 105 t/s (3/3 runs 100+), K=8 was 35 int pld_ngram = 1; // n-gram match size — 1 with multi-level fallback best bool pld_adaptive = false; // fixed K=10 is simpler and mean-optimal; adaptive --pld-adaptive int pld_min_hist = 20; // skip PLD until history >= this (avoid early-token false matches) // PLD degeneration guard (on by default): prevents PLD from amplifying repetition loops. 
bool pld_guard = true; // --pld-no-guard disables int pld_guard_distinct = 3; // reject draft if distinct tokens < this (≥K/3 heuristic) int pld_guard_tail = 6; // reject if draft[0] matches all last N hist tokens int pld_loop_warn = 8; // warn once when N consecutive identical tokens emitted float temperature = 0.0f; // 0 = greedy int top_k = 0; // 0 = disabled float top_p = 1.0f; // 1.0 = disabled uint64_t seed = 0; // 0 = use time // Qwen3 EOS tokens (from generation_config.json) std::vector eos_ids = {151645, 151643}; }; static bool parse_args(int argc, char** argv, Args& a) { for (int i = 1; i < argc; i++) { std::string s = argv[i]; auto next = [&](const char* f)->const char* { if (i + 1 >= argc) { fprintf(stderr, "missing value for %s\n", f); return nullptr; } return argv[++i]; }; if (s == "--model-dir") { auto v = next(s.c_str()); if (!v) return false; a.model_dir = v; } else if (s == "--prompt") { auto v = next(s.c_str()); if (!v) return false; a.prompt = v; } else if (s == "--vocab") { auto v = next(s.c_str()); if (!v) return false; a.vocab_path = v; } else if (s == "--n-predict") { auto v = next(s.c_str()); if (!v) return false; a.n_predict = std::atoi(v); } else if (s == "--tp-size") { auto v = next(s.c_str()); if (!v) return false; a.tp_size = std::atoi(v); } else if (s == "--num-layers") { auto v = next(s.c_str()); if (!v) return false; a.num_layers = std::atoi(v); } else if (s == "--max-seq") { auto v = next(s.c_str()); if (!v) return false; a.max_seq = std::atoi(v); } else if (s == "--device") { auto v = next(s.c_str()); if (!v) return false; a.device_id = std::atoi(v); } else if (s == "--temperature") { auto v = next(s.c_str()); if (!v) return false; a.temperature = (float)std::atof(v); } else if (s == "--top-k") { auto v = next(s.c_str()); if (!v) return false; a.top_k = std::atoi(v); } else if (s == "--top-p") { auto v = next(s.c_str()); if (!v) return false; a.top_p = (float)std::atof(v); } else if (s == "--seed") { auto v = next(s.c_str()); if 
(!v) return false; a.seed = (uint64_t)std::atoll(v); } else if (s == "--chat") { a.chat_template = true; } else if (s == "--no-stream") { a.stream = false; } else if (s == "--interactive" || s == "-i") { a.interactive = true; } else if (s == "--reset") { a.reset_each_turn = true; } else if (s == "--system") { auto v = next(s.c_str()); if (!v) return false; a.system_prompt = v; } else if (s == "--prompt-file") { auto v = next(s.c_str()); if (!v) return false; a.prompt_file = v; } else if (s == "--pld") { a.pld_enabled = true; } else if (s == "--pld-k") { auto v = next(s.c_str()); if (!v) return false; a.pld_k = std::atoi(v); } else if (s == "--pld-ngram") { auto v = next(s.c_str()); if (!v) return false; a.pld_ngram = std::atoi(v); } else if (s == "--pld-adaptive"){ a.pld_adaptive = true; } else if (s == "--pld-fixed-k") { a.pld_adaptive = false; } // opt out of adaptive else if (s == "--pld-min-hist"){ auto v = next(s.c_str()); if (!v) return false; a.pld_min_hist = std::atoi(v); } else if (s == "--pld-no-guard"){ a.pld_guard = false; } else if (s == "--pld-guard-distinct"){ auto v = next(s.c_str()); if (!v) return false; a.pld_guard_distinct = std::atoi(v); } else if (s == "--pld-guard-tail"){ auto v = next(s.c_str()); if (!v) return false; a.pld_guard_tail = std::atoi(v); } else if (s == "--pld-loop-warn"){ auto v = next(s.c_str()); if (!v) return false; a.pld_loop_warn = std::atoi(v); } else if (s == "--help" || s == "-h") { printf("Usage: %s --model-dir [options]\n", argv[0]); printf(" --prompt \"text\" prompt text (default: \"%s\")\n", a.prompt.c_str()); printf(" --prompt-file FILE read prompt from file (overrides --prompt)\n"); printf(" --n-predict N max tokens to generate (default: %d)\n", a.n_predict); printf(" --tp-size N tensor parallelism (default: 1; or TP_SIZE env)\n"); printf(" --num-layers N limit layers, testing only (default: all)\n"); printf(" --max-seq N KV cache + context cap (default: %d)\n", a.max_seq); printf(" --chat apply Qwen3 chat 
template\n"); printf(" --system \"text\" system role for chat\n"); printf(" --temperature F 0 = greedy; typical 0.7\n"); printf(" --top-k N 0 = disabled\n"); printf(" --top-p F 1.0 = disabled; typical 0.8\n"); printf(" --seed N 0 = time-based (default)\n"); printf(" --no-stream batch-print final text\n"); printf(" -i, --interactive REPL (multi-turn memory when --chat)\n"); printf(" --reset force stateless REPL (reset each turn)\n"); printf(" --pld enable Prompt Lookup Decoding (greedy only)\n"); printf(" --pld-k N draft window size (default: 4)\n"); printf(" --pld-ngram N match n-gram size (default: 2; multi-level fallback)\n"); printf(" --pld-adaptive adjust K based on recent accept rate\n"); printf(" --pld-min-hist N skip PLD until history >= N tokens (default: 20)\n"); printf(" --pld-no-guard disable degeneration guard (dangerous: can amplify loops)\n"); printf(" --pld-guard-distinct N reject draft with distinct tokens < N (default: 3)\n"); printf(" --pld-guard-tail N reject draft if draft[0] matches all last N hist (default: 6)\n"); printf(" --pld-loop-warn N warn once on N consecutive identical emitted tokens (default: 8)\n"); return false; } else { fprintf(stderr, "unknown arg: %s\n", s.c_str()); return false; } } if (a.model_dir.empty()) { fprintf(stderr, "--model-dir required\n"); return false; } if (const char* r = std::getenv("TP_RANK")) a.tp_rank = std::atoi(r); if (const char* s = std::getenv("TP_SIZE")) a.tp_size = std::atoi(s); return true; } // Sample next token from logits. temperature=0 → greedy argmax. Otherwise top-k / top-p. 
static int sample_token(const std::vector<uint16_t>& logits_bf16, int64_t V,
                        float temperature, int top_k, float top_p, std::mt19937& rng) {
    // Widen one bf16 logit to float32 (bf16 is the high half of an IEEE-754 float).
    const auto widen = [](uint16_t half) {
        const uint32_t bits = static_cast<uint32_t>(half) << 16;
        float out;
        std::memcpy(&out, &bits, sizeof out);
        return out;
    };

    // Greedy decoding: the first index holding the maximum logit wins.
    if (temperature <= 0.0f) {
        int arg = 0;
        float top = widen(logits_bf16[0]);
        for (int64_t idx = 1; idx < V; ++idx) {
            const float cand = widen(logits_bf16[idx]);
            if (cand > top) {
                top = cand;
                arg = static_cast<int>(idx);
            }
        }
        return arg;
    }

    // Temperature-scaled (score, token) candidates.
    using Cand = std::pair<float, int>;
    const auto by_score_desc = [](const Cand& x, const Cand& y) { return x.first > y.first; };
    std::vector<Cand> pool;
    pool.reserve(V);
    for (int64_t idx = 0; idx < V; ++idx) {
        pool.emplace_back(widen(logits_bf16[idx]) / temperature, static_cast<int>(idx));
    }

    // Top-k: partition so the k best come first, then drop the rest.
    if (top_k > 0 && top_k < static_cast<int>(pool.size())) {
        std::nth_element(pool.begin(), pool.begin() + top_k, pool.end(), by_score_desc);
        pool.resize(top_k);
    }
    std::sort(pool.begin(), pool.end(), by_score_desc);

    // Softmax, shifted by the maximum for numerical stability.
    const float peak = pool.front().first;
    double total = 0;
    for (auto& cand : pool) {
        cand.first = std::exp(cand.first - peak);
        total += cand.first;
    }
    for (auto& cand : pool) cand.first /= (float)total;

    // Nucleus (top-p): keep the smallest prefix whose mass reaches top_p, then renormalize.
    if (top_p > 0.0f && top_p < 1.0f) {
        double running = 0;
        size_t keep = pool.size();
        for (size_t j = 0; j < pool.size(); ++j) {
            running += pool[j].first;
            if (running >= top_p) {
                keep = j + 1;
                break;
            }
        }
        pool.resize(keep);
        double mass = 0;
        for (const auto& cand : pool) mass += cand.first;
        for (auto& cand : pool) cand.first /= (float)mass;
    }

    // Inverse-CDF draw; the last candidate absorbs floating-point round-off.
    std::uniform_real_distribution<float> unif(0.0f, 1.0f);
    const float r = unif(rng);
    float cdf = 0.0f;
    for (const auto& cand : pool) {
        cdf += cand.first;
        if (r <= cdf) return cand.second;
    }
    return pool.back().second;
}

// Broadcast a prompt's token_ids from rank 0 to all ranks (no-op at TP=1). For TP>1 the
// non-master ranks need the tokens before prefill: rank 0 broadcasts the count, then the ids,
// and non-master ranks get `ids` resized and overwritten. Returns false on HCCL failure or if
// the count is outside [1, max_seq]. Device staging buffers are allocated per call.
static bool broadcast_token_ids(Runner& runner, std::vector& ids, int64_t max_seq, bool is_master) { const ModelConfig& cfg = runner.cfg(); if (cfg.tp_size <= 1) return true; // Step 1: broadcast count (as int32 on device) DeviceBuffer cnt_dev(4); int32_t cnt = is_master ? (int32_t)ids.size() : 0; ACL_CHECK(aclrtMemcpy(cnt_dev.get(), 4, &cnt, 4, ACL_MEMCPY_HOST_TO_DEVICE)); // Access Runner's HCCL context via stream (exposed) and rely on the fact that ctx.comm is owned. // Since hccl_broadcast needs HcclCtx, we need access. Cheapest: friend access via a shim member. // For now, Runner has a stream() accessor; HCCL ctx is private. We'll accept that and broadcast // via a direct call on the comm — but ctx is hidden. Workaround: expose hccl_ctx() on Runner. // ... (see Runner::hccl_ctx() accessor added below) extern HcclCtx* runner_hccl_ctx_shim(Runner& r); // forward from runner.cpp HcclCtx* ctx = runner_hccl_ctx_shim(runner); if (!ctx) return false; if (!hccl_broadcast(*ctx, cnt_dev.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream())) return false; ACL_CHECK(aclrtSynchronizeStream(runner.stream())); ACL_CHECK(aclrtMemcpy(&cnt, 4, cnt_dev.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST)); if (cnt <= 0 || cnt > (int32_t)max_seq) { fprintf(stderr, "[rank %d] broadcast: bad count %d\n", cfg.tp_rank, cnt); return false; } // Step 2: broadcast the id buffer DeviceBuffer ids_dev(cnt * 4); if (is_master) { ACL_CHECK(aclrtMemcpy(ids_dev.get(), cnt*4, ids.data(), cnt*4, ACL_MEMCPY_HOST_TO_DEVICE)); } if (!hccl_broadcast(*ctx, ids_dev.get(), cnt, HCCL_DATA_TYPE_INT32, 0, runner.stream())) return false; ACL_CHECK(aclrtSynchronizeStream(runner.stream())); if (!is_master) { ids.resize(cnt); ACL_CHECK(aclrtMemcpy(ids.data(), cnt*4, ids_dev.get(), cnt*4, ACL_MEMCPY_DEVICE_TO_HOST)); } return true; } // Run one generation turn. Assumes KV cache is reset. Returns perf summary. 
struct TurnStats { double prefill_ms = 0; double decode_ms = 0; int n_prompt = 0; int decoded = 0; bool hit_eos = false; }; static TurnStats run_turn(Runner& runner, Tokenizer& tokenizer, const Args& args, const std::string& prompt, std::mt19937& rng, bool is_master) { TurnStats st; // --- Tokenize (on master; broadcast for TP>1) --- std::vector input_ids; if (is_master) { auto raw = tokenizer.encode_via_python(args.model_dir, prompt, args.chat_template); if (raw.empty()) return st; input_ids.reserve(raw.size()); for (int v : raw) input_ids.push_back((int32_t)v); } if (args.tp_size > 1) { if (!broadcast_token_ids(runner, input_ids, args.max_seq, is_master)) return st; } if (input_ids.empty()) return st; const int64_t V = runner.cfg().vocab_size; std::vector logits_h(V); auto load_logits = [&](DeviceBuffer& buf) { ACL_CHECK(aclrtMemcpy(logits_h.data(), V*2, buf.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST)); }; auto is_eos = [&](int id) { for (int e : args.eos_ids) if (id == e) return true; return false; }; // --- Prefill --- st.n_prompt = (int)input_ids.size(); auto t0 = std::chrono::steady_clock::now(); DeviceBuffer logits; if (!runner.prefill(input_ids.data(), (int64_t)input_ids.size(), logits)) return st; auto t1 = std::chrono::steady_clock::now(); st.prefill_ms = std::chrono::duration(t1 - t0).count(); load_logits(logits); int next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng); if (is_master && args.stream) { if (!args.chat_template) printf("%s", prompt.c_str()); printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); } std::vector generated = { next_id }; st.hit_eos = is_eos(next_id); // All tokens (prompt + generated) for PLD n-gram lookup. Non-master ranks still need to // track consistent length for HCCL broadcast of draft proposals. 
std::vector hist; hist.reserve(input_ids.size() + args.n_predict + 16); for (auto x : input_ids) hist.push_back(x); hist.push_back(next_id); // PLD n-gram lookup: search for suffix match ending at end-of-hist; return K tokens following match. // Longer matches = more reliable drafts. Multi-level: try n, fall back to smaller n if no match. auto lookup_one = [&](int ngram, int K) -> std::vector { int hs = (int)hist.size(); if (hs < ngram + 1 || K <= 0) return {}; for (int start = hs - ngram - 1; start >= 0; start--) { bool match = true; for (int k = 0; k < ngram; k++) { if (hist[start + k] != hist[hs - ngram + k]) { match = false; break; } } if (match) { int after = start + ngram; std::vector d; for (int k = 0; k < K && after + k < hs; k++) { d.push_back(hist[after + k]); if (is_eos(hist[after + k])) break; } if (!d.empty()) return d; } } return {}; }; // Multi-level: try configured n first, then n-1, then n-2 (down to 1). auto lookup_draft = [&](int ngram, int K) -> std::vector { for (int n = ngram; n >= 1; n--) { auto d = lookup_one(n, K); if (!d.empty()) return d; } return {}; }; // Degeneration guard: classify a draft as repetition-induced so we can fall back to single // decode (and avoid PLD amplifying model's own repetition loop into a runaway "W W W …" mess). // Returns nullptr if draft is OK, else a short reason string for stats. auto draft_degenerate = [&](const std::vector& d) -> const char* { if (!args.pld_guard || d.empty()) return nullptr; // (1) distinct-token count: a draft of K tokens with < args.pld_guard_distinct distinct // values means n-gram is echoing a loop. Only apply when draft is long enough. 
if ((int)d.size() >= 3) { int distinct = 0; for (int i = 0; i < (int)d.size(); i++) { bool seen = false; for (int j = 0; j < i; j++) { if (d[j] == d[i]) { seen = true; break; } } if (!seen) distinct++; } if (distinct < args.pld_guard_distinct) return "low-distinct"; } // (2) tail echo: if the last N hist tokens are all equal to draft[0], the model is already // in a short loop — accepting the draft will just confirm the loop at batch speed. int tail_n = std::min(args.pld_guard_tail, (int)hist.size()); if (tail_n >= 3) { int matches = 0; for (int i = (int)hist.size() - tail_n; i < (int)hist.size(); i++) { if (hist[i] == d[0]) matches++; } if (matches == tail_n) return "tail-echo"; } return nullptr; }; // --- Decode loop --- auto t2 = std::chrono::steady_clock::now(); int pld_verifies = 0, pld_accepted = 0; int pld_rej_lowdist = 0, pld_rej_tailecho = 0; // guard rejection counters bool loop_warned = false; // warn-once state // Adaptive K state: recent accept counts for moving-average decisions const int ADAPT_WINDOW = 8; std::vector recent_accepts; int current_k = args.pld_k; bool pld_disabled_adapt = false; // set true when recent accept rate is too low to benefit while (st.decoded < args.n_predict - 1 && !st.hit_eos) { // Adaptive K: scale K with recent accept rate. // No auto-disable: since S=K+1 forward ≈ S=1 forward (latency-bound), even accept=0.1 // still nets slightly positive — PLD doesn't "hurt" as long as ngram lookup is cheap. if (args.pld_adaptive && (int)recent_accepts.size() >= ADAPT_WINDOW) { double avg = 0; for (int a : recent_accepts) avg += a; avg /= recent_accepts.size(); // Aim: K = 2*avg + 4 (generous window to catch upswings). Clamp [4, 12]. 
current_k = std::max(4, std::min(12, (int)std::round(2.0 * avg + 4.0))); } // Try PLD speculation path — skip until enough history accumulated std::vector draft; if (args.pld_enabled && (int)hist.size() >= args.pld_min_hist && is_master) { draft = lookup_draft(args.pld_ngram, current_k); // Degeneration guard: if draft looks like repetition-loop echo, drop it so this // iteration falls through to normal single decode. This does NOT stop a loop the model // is already in (greedy is deterministic), but it prevents PLD from running the loop // at batch speed while masquerading as a speedup. if (!draft.empty()) { const char* reason = draft_degenerate(draft); if (reason) { if (reason[0] == 'l') pld_rej_lowdist++; else pld_rej_tailecho++; draft.clear(); } } } // For TP>1, broadcast draft across ranks. Only broadcast if master has a non-empty draft; // otherwise all ranks take the no-draft path (normal decode). bool has_draft = is_master ? !draft.empty() : false; // Broadcast the "has_draft" flag (using a 1-element count: 1 = yes, 0 = no) if (args.tp_size > 1) { extern HcclCtx* runner_hccl_ctx_shim(Runner&); HcclCtx* ctx = runner_hccl_ctx_shim(runner); DeviceBuffer flag(4); int32_t f = has_draft ? 
1 : 0; ACL_CHECK(aclrtMemcpy(flag.get(), 4, &f, 4, ACL_MEMCPY_HOST_TO_DEVICE)); hccl_broadcast(*ctx, flag.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream()); ACL_CHECK(aclrtSynchronizeStream(runner.stream())); ACL_CHECK(aclrtMemcpy(&f, 4, flag.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST)); has_draft = (f != 0); if (has_draft) { std::vector d = draft; broadcast_token_ids(runner, d, args.max_seq, is_master); draft = d; } else { draft.clear(); } } if (args.pld_enabled && (int)draft.size() >= 1 && args.temperature == 0.0f) { // Batch verify: input = [next_id, draft[0], ..., draft[K-1]] std::vector batch_input = { next_id }; for (auto d : draft) batch_input.push_back(d); int S = (int)batch_input.size(); DeviceBuffer batch_logits; if (!runner.decode_batch(batch_input.data(), S, batch_logits)) break; std::vector blh(S * V); if (is_master) ACL_CHECK(aclrtMemcpy(blh.data(), S*V*2, batch_logits.get(), S*V*2, ACL_MEMCPY_DEVICE_TO_HOST)); // Accept longest prefix: draft[i] is "candidate" for position past+i+1. // blh row i predicts position past+i+1 (follows batch_input[i]). // Verify: blh[0].argmax == draft[0]? 
(i.e., does model agree with draft's first proposal) int accept = 0, new_next = next_id; if (is_master) { for (int i = 0; i < S - 1; i++) { int pred = 0; float bv = bf16_to_float(blh[i * V]); for (int k = 1; k < V; k++) { float v = bf16_to_float(blh[i*V + k]); if (v > bv) { bv = v; pred = k; } } if (pred == (int)draft[i]) accept++; else { new_next = pred; break; } } if (accept == S - 1) { // All draft accepted, bonus from last row int pred = 0; float bv = bf16_to_float(blh[(S-1) * V]); for (int k = 1; k < V; k++) { float v = bf16_to_float(blh[(S-1)*V + k]); if (v > bv) { bv = v; pred = k; } } new_next = pred; } } // Broadcast accept count + new_next across TP ranks if (args.tp_size > 1) { int32_t packed[2] = { (int32_t)accept, (int32_t)new_next }; std::vector p(packed, packed + 2); broadcast_token_ids(runner, p, args.max_seq, is_master); if (p.size() == 2) { accept = p[0]; new_next = p[1]; } } // Rewind KV for rejected drafts int64_t rewind = (int64_t)(S - 1 - accept); // drafts not accepted (excluding bonus) if (rewind > 0) runner.rewind_cache(rewind); // Commit accepted drafts + bonus to hist and emit for (int i = 0; i < accept; i++) { int tok = (int)draft[i]; hist.push_back(tok); generated.push_back(tok); st.decoded++; if (is_master && args.stream) { printf("%s", tokenizer.decode(tok).c_str()); fflush(stdout); } if (is_eos(tok)) { st.hit_eos = true; break; } } pld_verifies++; pld_accepted += accept; // Track recent accept for adaptive K if (args.pld_adaptive) { recent_accepts.push_back(accept); if ((int)recent_accepts.size() > ADAPT_WINDOW) recent_accepts.erase(recent_accepts.begin()); } if (st.hit_eos) break; // Bonus token (new_next) is also committed hist.push_back(new_next); generated.push_back(new_next); st.decoded++; if (is_master && args.stream) { printf("%s", tokenizer.decode(new_next).c_str()); fflush(stdout); } if (is_eos(new_next)) { st.hit_eos = true; break; } next_id = new_next; } else { // Normal decode DeviceBuffer logits2; if 
(!runner.decode((int32_t)next_id, logits2)) break; load_logits(logits2); next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng); hist.push_back(next_id); generated.push_back(next_id); st.decoded++; if (is_master && args.stream) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); } if (is_eos(next_id)) { st.hit_eos = true; break; } } // Loop-warn: emit a one-shot warning to stderr if the tail of generated is all-same-token. // Does not stop generation (user may want to see what happens) — just flags output quality. if (is_master && !loop_warned && args.pld_loop_warn > 0 && (int)generated.size() >= args.pld_loop_warn) { int tail = args.pld_loop_warn; int anchor = generated[(int)generated.size() - tail]; bool all_same = true; for (int i = (int)generated.size() - tail + 1; i < (int)generated.size(); i++) { if (generated[i] != anchor) { all_same = false; break; } } if (all_same) { fprintf(stderr, "\n[warn] %d consecutive identical tokens — likely degeneration loop; output after this point is suspect\n", tail); loop_warned = true; } } } auto t3 = std::chrono::steady_clock::now(); st.decode_ms = std::chrono::duration(t3 - t2).count(); if (is_master && args.pld_enabled) { if (pld_verifies > 0) { fprintf(stderr, "\n[pld] %d verifies, %d drafts accepted, avg=%.2f", pld_verifies, pld_accepted, (double)pld_accepted / pld_verifies); } else { fprintf(stderr, "\n[pld] 0 verifies (all drafts blocked or none found)"); } if (args.pld_guard && (pld_rej_lowdist + pld_rej_tailecho) > 0) { fprintf(stderr, "; guard rejections: low-distinct=%d tail-echo=%d", pld_rej_lowdist, pld_rej_tailecho); } fprintf(stderr, "\n"); } if (is_master) { if (args.stream) { printf("\n"); fflush(stdout); } else { std::string text = tokenizer.decode(generated); printf("%s%s\n", args.chat_template ? 
"" : prompt.c_str(), text.c_str()); } } return st; } static bool load_file(const std::string& path, std::string& out) { FILE* f = fopen(path.c_str(), "rb"); if (!f) { fprintf(stderr, "[cli] cannot open %s\n", path.c_str()); return false; } fseek(f, 0, SEEK_END); long sz = ftell(f); fseek(f, 0, SEEK_SET); out.resize(sz); size_t n = fread(out.data(), 1, sz, f); fclose(f); if ((long)n != sz) { fprintf(stderr, "[cli] short read from %s\n", path.c_str()); return false; } // Strip a single trailing newline (common in text files) if (!out.empty() && out.back() == '\n') out.pop_back(); return true; } int main(int argc, char** argv) { Args args; if (!parse_args(argc, argv, args)) return 1; // --prompt-file overrides --prompt if (!args.prompt_file.empty()) { if (!load_file(args.prompt_file, args.prompt)) return 1; } const bool is_master = (args.tp_rank == 0); std::mt19937 rng(args.seed ? args.seed : (uint64_t)std::chrono::steady_clock::now().time_since_epoch().count()); if (is_master) { printf("[cli] model=%s\n", args.model_dir.c_str()); printf("[cli] tp=%d n_predict=%d temp=%.2f top_k=%d top_p=%.2f chat=%d interactive=%d\n", args.tp_size, args.n_predict, args.temperature, args.top_k, args.top_p, args.chat_template, args.interactive); fflush(stdout); } Tokenizer tokenizer; if (!tokenizer.load(args.vocab_path)) { fprintf(stderr, "[cli] failed to load vocab %s\n", args.vocab_path.c_str()); return 1; } Runner runner; int num_layers = args.num_layers; if (num_layers == 0) { ModelConfig probe; if (!probe.load_from_json(args.model_dir + "/config.json")) return 1; num_layers = (int)probe.num_hidden_layers; } if (!runner.init(args.model_dir, args.tp_size, args.tp_rank, num_layers, args.max_seq, args.device_id)) return 1; if (const char* p = std::getenv("LCA_PROFILE"); p && std::atoi(p) != 0) { runner.profile_enabled = true; } // Warmup: cut cold-start latency. Controlled via LCA_WARMUP env (default 0 to keep behavior). 
if (const char* w = std::getenv("LCA_WARMUP"); w) { int n = std::atoi(w); if (n > 0) runner.warmup(n); } if (args.interactive) { const bool multi_turn = args.chat_template && !args.reset_each_turn; if (is_master) { printf("\n[cli] === interactive mode ===\n"); if (multi_turn) { printf("[cli] multi-turn chat (KV cache preserved). Commands: 'quit', 'reset'.\n"); if (!args.system_prompt.empty()) { printf("[cli] system: %s\n", args.system_prompt.c_str()); } } else { printf("[cli] stateless mode (KV cache reset each turn). Command: 'quit'.\n"); if (!args.chat_template) { printf("[cli] (hint: add --chat for multi-turn conversational memory)\n"); } } fflush(stdout); } // Conversation history: accumulated (role, content) pairs. System prompt seeded if present. std::vector> conversation; if (multi_turn && !args.system_prompt.empty()) { conversation.emplace_back("system", args.system_prompt); } auto* hccl_ctx = runner_hccl_ctx_shim(runner); // Signal types (broadcast as int32): 0 = normal turn, 1 = quit, 2 = reset. 
auto broadcast_signal = [&](int32_t sig)->int32_t { if (args.tp_size <= 1) return sig; DeviceBuffer s(4); ACL_CHECK(aclrtMemcpy(s.get(), 4, &sig, 4, ACL_MEMCPY_HOST_TO_DEVICE)); hccl_broadcast(*hccl_ctx, s.get(), 1, HCCL_DATA_TYPE_INT32, 0, runner.stream()); ACL_CHECK(aclrtSynchronizeStream(runner.stream())); int32_t r; ACL_CHECK(aclrtMemcpy(&r, 4, s.get(), 4, ACL_MEMCPY_DEVICE_TO_HOST)); return r; }; while (true) { std::string prompt; int32_t sig = 0; if (is_master) { printf("\n> "); fflush(stdout); if (!std::getline(std::cin, prompt)) sig = 1; else if (prompt == "quit" || prompt == "exit") sig = 1; else if (prompt == "reset") sig = 2; else if (prompt.empty()) sig = 3; // skip } sig = broadcast_signal(sig); if (sig == 1) break; if (sig == 2) { runner.reset_cache(); conversation.clear(); if (multi_turn && !args.system_prompt.empty()) conversation.emplace_back("system", args.system_prompt); if (is_master) { printf("[cli] (cache + conversation reset)\n"); fflush(stdout); } continue; } if (sig == 3) continue; TurnStats st; if (multi_turn) { // Append user message and tokenize full conversation. Prefill DELTA only. if (is_master) conversation.emplace_back("user", prompt); // Also ranks 1..N-1 need to track conversation (needed for correct delta count on // subsequent turns if TP ever tokenizes per-rank — currently rank 0 tokenizes). std::vector full_ids; if (is_master) { auto raw = tokenizer.encode_conversation_via_python(args.model_dir, conversation, /*gen_prompt=*/true); full_ids.reserve(raw.size()); for (int v : raw) full_ids.push_back((int32_t)v); } // Broadcast full_ids (variable-length). Use the same shim as broadcast_token_ids. 
if (args.tp_size > 1) { if (!broadcast_token_ids(runner, full_ids, args.max_seq, is_master)) break; } if (full_ids.empty()) { if (is_master) printf("[cli] tokenize failed\n"); continue; } int64_t past = runner.past_len(); if ((int64_t)full_ids.size() < past) { runner.reset_cache(); past = 0; } std::vector delta(full_ids.begin() + past, full_ids.end()); if (delta.empty()) { if (is_master) printf("[cli] (no new tokens)\n"); continue; } // Overflow check — simple policy: warn + auto-reset if the turn + generation // would exceed max_seq. Conversation history is cleared (except --system) so // the user's current prompt still fits. if ((int64_t)(past + delta.size()) + args.n_predict > args.max_seq) { if (is_master) { fprintf(stderr, "[cli] context %ld + gen %d > max_seq %d — auto-resetting\n", (long)(past + delta.size()), args.n_predict, args.max_seq); } runner.reset_cache(); // Rebuild conversation: keep only system + current user turn. if (is_master) { std::vector> fresh; for (auto& m : conversation) if (m.first == "system") fresh.push_back(m); if (!conversation.empty() && conversation.back().first == "user") { fresh.push_back(conversation.back()); } conversation = std::move(fresh); auto raw = tokenizer.encode_conversation_via_python(args.model_dir, conversation, true); full_ids.clear(); for (int v : raw) full_ids.push_back((int32_t)v); } if (args.tp_size > 1) { if (!broadcast_token_ids(runner, full_ids, args.max_seq, is_master)) break; } delta.assign(full_ids.begin(), full_ids.end()); past = 0; } // --- Prefill the delta --- st.n_prompt = (int)delta.size(); auto t0 = std::chrono::steady_clock::now(); DeviceBuffer logits; if (!runner.prefill(delta.data(), (int64_t)delta.size(), logits)) break; auto t1 = std::chrono::steady_clock::now(); st.prefill_ms = std::chrono::duration(t1 - t0).count(); const int64_t V = runner.cfg().vocab_size; std::vector logits_h(V); auto load_logits = [&](DeviceBuffer& buf) { ACL_CHECK(aclrtMemcpy(logits_h.data(), V*2, buf.get(), V*2, 
ACL_MEMCPY_DEVICE_TO_HOST)); }; auto is_eos = [&](int id) { for (int e : args.eos_ids) if (id == e) return true; return false; }; load_logits(logits); int next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng); std::vector assistant_ids = { next_id }; if (is_master) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); } st.hit_eos = is_eos(next_id); auto t2 = std::chrono::steady_clock::now(); for (int step = 1; step < args.n_predict && !st.hit_eos; step++) { DeviceBuffer logits2; if (!runner.decode((int32_t)next_id, logits2)) break; load_logits(logits2); next_id = sample_token(logits_h, V, args.temperature, args.top_k, args.top_p, rng); assistant_ids.push_back(next_id); st.decoded++; if (is_master) { printf("%s", tokenizer.decode(next_id).c_str()); fflush(stdout); } if (is_eos(next_id)) { st.hit_eos = true; break; } } auto t3 = std::chrono::steady_clock::now(); st.decode_ms = std::chrono::duration(t3 - t2).count(); if (is_master) { printf("\n"); fflush(stdout); } // Record assistant reply in conversation (strip trailing EOS before decode, // and trim incomplete UTF-8 tail if generation was cut mid-codepoint). if (is_master) { std::vector content_ids; for (int id : assistant_ids) { if (is_eos(id)) break; content_ids.push_back(id); } conversation.emplace_back("assistant", utf8_trim_incomplete(tokenizer.decode(content_ids))); } } else { // Stateless: reset cache, one-shot prompt runner.reset_cache(); st = run_turn(runner, tokenizer, args, prompt, rng, is_master); } if (is_master) { double tgs = (st.decode_ms > 0) ? (st.decoded * 1000.0 / st.decode_ms) : 0.0; printf("[perf] prefill %d tok %.0fms decode %d tok %.0fms = %.2f t/s%s past_len=%ld\n", st.n_prompt, st.prefill_ms, st.decoded, st.decode_ms, tgs, st.hit_eos ? 
" (EOS)" : "", runner.past_len()); fflush(stdout); } } if (is_master) printf("[cli] bye\n"); return 0; } // One-shot mode TurnStats st = run_turn(runner, tokenizer, args, args.prompt, rng, is_master); if (is_master) runner.print_profile_summary(); if (is_master) { if (st.hit_eos) printf("[cli] (hit EOS)\n"); printf("\n[perf] prefill: %.1fms for %d tokens = %.2f t/s\n", st.prefill_ms, st.n_prompt, (st.prefill_ms > 0) ? (st.n_prompt * 1000.0 / st.prefill_ms) : 0.0); if (st.decoded > 0) { printf("[perf] decode : %.1fms for %d tokens = %.2f t/s (TG)\n", st.decode_ms, st.decoded, (st.decoded * 1000.0) / st.decode_ms); } } return 0; }