// test_batch_decode.cpp — benchmark decode with different batch sizes S = 1, 2, 4, 8. // // Purpose: quantify the cost of "batched decode" (a.k.a. the ingredient speculative decoding // relies on). If Runner.prefill(S=K) forward-pass is only a small overhead over S=1, then // spec-decoding with K draft tokens gives ~K× speedup at high accept rate. // // Per-token amortized cost: // cost(S) / S // Speculative decoding benefit: // expected_accept_rate * K = effective tokens per forward // → TG = expected / (cost(S=K+1) / 1_sec) #include "runner.h" #include #include #include #include int main() { const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16"; Runner r; int tp_rank = 0, tp_size = 1; if (const char* v = std::getenv("TP_RANK")) tp_rank = std::atoi(v); if (const char* v = std::getenv("TP_SIZE")) tp_size = std::atoi(v); bool is_master = tp_rank == 0; if (!r.init(model_dir, tp_size, tp_rank, /*num_layers=*/94, /*max_seq=*/512)) return 1; // Prefill a short context so decode has some KV cache std::vector prompt = {785, 6722, 315, 9625, 374}; // "The capital of France is" DeviceBuffer logits; r.prefill(prompt.data(), prompt.size(), logits); auto now = []() { return std::chrono::steady_clock::now(); }; auto ms = [](auto t0, auto t1) { return std::chrono::duration(t1 - t0).count(); }; std::vector batch_sizes = {1, 2, 4, 8}; int N_ITERS = 20; if (is_master) { printf("\n=== Batched decode forward benchmark (94 layers, TP=%d) ===\n", tp_size); printf("Each row: forward with S=K new tokens after prefill\n"); printf("%-5s %-12s %-18s %-18s %s\n", "S", "ms/forward", "ms/token (amort)", "tokens/sec", "vs S=1 efficiency"); } double base_per_token = 0; for (int S : batch_sizes) { // Reset cache between measurements to keep cache size fair (same position for each) // Actually we want to simulate: after some past_len, do 1 forward with S new tokens. // Use prefill() which accepts S>=1. std::vector times; for (int iter = 0; iter < N_ITERS + 3; iter++) { // +3 for warmup r.reset_cache(); r.prefill(prompt.data(), prompt.size(), logits); // re-prefill // New forward with S tokens (as if doing speculative verify) std::vector new_tokens(S, 100); // dummy token ids auto t0 = now(); DeviceBuffer logits2; r.prefill(new_tokens.data(), S, logits2); auto t1 = now(); if (iter >= 3) times.push_back(ms(t0, t1)); } std::sort(times.begin(), times.end()); double median_ms = times[times.size() / 2]; double per_token = median_ms / S; double tok_per_sec = 1000.0 / per_token; if (S == 1) base_per_token = per_token; double efficiency = base_per_token / per_token * 100.0; if (is_master) { printf("%-5d %-12.2f %-18.2f %-18.2f %.1f%%\n", S, median_ms, per_token, tok_per_sec, efficiency); } } if (is_master) { printf("\n=== Interpretation ===\n"); printf("If S=4 forward ~ S=1 (efficiency high), spec decoding with accept_rate=70%%\n"); printf("gives TG = 0.7*4 / cost(S=5) vs baseline 1 / cost(S=1) = up to 2.8× speedup.\n"); } return 0; }