// test_batch_correctness.cpp — verify that forward with S>1 at past_len>0 produces the // same logits at each position as sequential S=1 decodes. // // This is the foundation for speculative decoding / PLD: the main model must predict logits // for each of K candidate positions in one batched forward pass matching sequential behavior. #include "runner.h" #include #include #include #include static float bf16_to_float(uint16_t x) { uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f; } int main() { const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16"; int tp_rank = 0, tp_size = 1; if (const char* v = std::getenv("TP_RANK")) tp_rank = std::atoi(v); if (const char* v = std::getenv("TP_SIZE")) tp_size = std::atoi(v); bool is_master = tp_rank == 0; Runner r; if (!r.init(model_dir, tp_size, tp_rank, 94, 512)) return 1; const int64_t V = r.cfg().vocab_size; // Prefix std::vector prompt = {785, 6722, 315, 9625, 374}; DeviceBuffer logits0; r.prefill(prompt.data(), prompt.size(), logits0); std::vector h_last0(V); if (is_master) ACL_CHECK(aclrtMemcpy(h_last0.data(), V*2, logits0.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST)); int next0 = 0; if (is_master) { float best = -1e30; for (int i = 0; i < V; i++) { float v = bf16_to_float(h_last0[i]); if (v > best) { best = v; next0 = i; } } } // Broadcast next0 (simple: let rank 0 decide and non-master ranks independently too) int32_t token_seq[4]; if (is_master) token_seq[0] = next0; // --- Path A: sequential S=1 decode × 4 times --- std::vector> seq_logits(4); for (int i = 0; i < 4; i++) seq_logits[i].resize(V); // first decode: takes prompt's last logit argmax // Here we need identical approach on all ranks. Use random token id for consistency. std::vector seq_tokens = {next0, 100, 200, 300}; // deterministic for test for (int i = 0; i < 4; i++) { DeviceBuffer out; r.decode(seq_tokens[i], out); if (is_master) ACL_CHECK(aclrtMemcpy(seq_logits[i].data(), V*2, out.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST)); } int64_t past_after_seq = r.past_len(); // --- Path B: reset, re-prefill, then ONE batch forward with S=4 --- r.reset_cache(); DeviceBuffer logits_reprefill; r.prefill(prompt.data(), prompt.size(), logits_reprefill); DeviceBuffer batch_logits; r.prefill(seq_tokens.data(), 4, batch_logits); // prefill returns logits for LAST position only (S=4 gives [1, V], not [4, V]). // Hmm — that's a limitation. To do PLD we need logits for all 4 positions. // For now, just compare the LAST one (position 4 after prefix). std::vector batch_last(V); if (is_master) ACL_CHECK(aclrtMemcpy(batch_last.data(), V*2, batch_logits.get(), V*2, ACL_MEMCPY_DEVICE_TO_HOST)); if (is_master) { printf("\n=== Batched vs Sequential Decode Correctness ===\n"); double l2d=0, l2r=0, maxd=0; for (int i = 0; i < V; i++) { float a = bf16_to_float(batch_last[i]), b = bf16_to_float(seq_logits[3][i]); l2d += (a-b)*(a-b); l2r += b*b; if (std::abs(a-b) > maxd) maxd = std::abs(a-b); } double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10); printf("Last-position logits:\n"); printf(" seq[3] argmax = "); { int b = 0; float bv = bf16_to_float(seq_logits[3][0]); for (int i = 1; i < V; i++) if (bf16_to_float(seq_logits[3][i]) > bv) { bv = bf16_to_float(seq_logits[3][i]); b = i; } printf("%d (%.3f)\n", b, bv); } printf(" batch argmax = "); { int b = 0; float bv = bf16_to_float(batch_last[0]); for (int i = 1; i < V; i++) if (bf16_to_float(batch_last[i]) > bv) { bv = bf16_to_float(batch_last[i]); b = i; } printf("%d (%.3f)\n", b, bv); } printf(" rel=%.4e max=%.4f\n", rel, maxd); printf(" %s\n", rel < 5e-2 ? "PASS" : "FAIL (batch forward diverges from sequential)"); printf("\nNote: current Runner.prefill() returns ONLY last-position logits. For PLD\n"); printf("we need all-position logits: requires extending prefill to optionally output\n"); printf("[S, V] logits tensor.\n"); } return 0; }