File size: 4,435 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
// test_batch_correctness.cpp — verify that a forward pass with S>1 at past_len>0 produces the
// same logits as sequential S=1 decodes over the same tokens.
//
// This is the foundation for speculative decoding / PLD: the main model must predict logits
// for each of K candidate positions in one batched forward pass, matching sequential behavior.
// Currently only the LAST position is compared, because prefill() returns last-position logits only.
#include "runner.h"

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>

// Converts a raw bfloat16 bit pattern to float by placing the 16 bits in the upper half of a
// 32-bit word (lower half zero) and reinterpreting that word as a float.
static float bf16_to_float(uint16_t x) {
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    float result;
    // memcpy is the well-defined way to reinterpret the bits without aliasing violations.
    std::memcpy(&result, &widened, sizeof(widened));
    return result;
}

// Returns the index of the largest logit in `logits` (raw bf16 bit patterns). When `best_value`
// is non-null the winning value is stored there. Ties keep the lowest index; an empty vector
// yields index 0 and value 0.
static int argmax_bf16(const std::vector<uint16_t>& logits, float* best_value = nullptr) {
    int best_idx = 0;
    float best = logits.empty() ? 0.0f : bf16_to_float(logits[0]);
    for (size_t i = 1; i < logits.size(); i++) {
        const float v = bf16_to_float(logits[i]);
        if (v > best) { best = v; best_idx = static_cast<int>(i); }
    }
    if (best_value) *best_value = best;
    return best_idx;
}

// Compares the last-position logits of one batched forward pass (S=4 on top of a prefilled
// prompt) against the logits of the 4th of four sequential single-token decodes over the same
// tokens.
//
// Environment: TP_RANK / TP_SIZE select the tensor-parallel rank and world size; the optional
// MODEL_DIR overrides the built-in model path. Only rank 0 copies logits to the host and prints.
//
// Returns 0 when the comparison passes, 1 when init fails or the comparison fails.
int main() {
    std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
    if (const char* v = std::getenv("MODEL_DIR")) model_dir = v;
    int tp_rank = 0, tp_size = 1;
    if (const char* v = std::getenv("TP_RANK")) tp_rank = std::atoi(v);
    if (const char* v = std::getenv("TP_SIZE")) tp_size = std::atoi(v);
    const bool is_master = tp_rank == 0;

    Runner r;
    // NOTE(review): 94 and 512 are presumably the layer count and max sequence length —
    // confirm against Runner::init.
    if (!r.init(model_dir, tp_size, tp_rank, 94, 512)) return 1;
    const int64_t V = r.cfg().vocab_size;
    const size_t logits_bytes = static_cast<size_t>(V) * sizeof(uint16_t);

    // Prefix
    std::vector<int32_t> prompt = {785, 6722, 315, 9625, 374};
    DeviceBuffer logits0;
    r.prefill(prompt.data(), prompt.size(), logits0);
    int next0 = 0;
    if (is_master) {
        std::vector<uint16_t> h_last0(V);
        ACL_CHECK(aclrtMemcpy(h_last0.data(), logits_bytes, logits0.get(), logits_bytes, ACL_MEMCPY_DEVICE_TO_HOST));
        next0 = argmax_bf16(h_last0);
    }

    // Every rank must feed identical token ids. Only rank 0 reads the prefill logits back, so
    // the argmax is usable as the first token only in single-rank runs; multi-rank runs fall
    // back to a fixed id so that ranks do not diverge (previously non-master ranks used 0
    // while rank 0 used the argmax).
    const int32_t first_token = (tp_size > 1) ? 1000 : next0;
    std::vector<int32_t> seq_tokens = {first_token, 100, 200, 300};  // deterministic for test
    const size_t num_tokens = seq_tokens.size();

    // --- Path A: sequential S=1 decode, one call per token ---
    std::vector<std::vector<uint16_t>> seq_logits(num_tokens, std::vector<uint16_t>(V));
    for (size_t i = 0; i < num_tokens; i++) {
        DeviceBuffer out;
        r.decode(seq_tokens[i], out);
        if (is_master) ACL_CHECK(aclrtMemcpy(seq_logits[i].data(), logits_bytes, out.get(), logits_bytes, ACL_MEMCPY_DEVICE_TO_HOST));
    }
    const int64_t past_after_seq = r.past_len();

    // --- Path B: reset, re-prefill, then ONE batch forward over all tokens ---
    r.reset_cache();
    DeviceBuffer logits_reprefill;
    r.prefill(prompt.data(), prompt.size(), logits_reprefill);

    DeviceBuffer batch_logits;
    r.prefill(seq_tokens.data(), num_tokens, batch_logits);
    const int64_t past_after_batch = r.past_len();
    // prefill returns logits for the LAST position only (S=4 gives [1, V], not [4, V]).
    // PLD needs logits for all positions; for now only the last one is compared.

    std::vector<uint16_t> batch_last(V);
    if (is_master) ACL_CHECK(aclrtMemcpy(batch_last.data(), logits_bytes, batch_logits.get(), logits_bytes, ACL_MEMCPY_DEVICE_TO_HOST));

    // Both paths consumed the same prompt plus the same tokens, so the cache length reported
    // by the runner should agree; a mismatch means the batched call did not append to the cache.
    bool ok = past_after_batch == past_after_seq;

    if (is_master) {
        const std::vector<uint16_t>& seq_last = seq_logits.back();
        printf("\n=== Batched vs Sequential Decode Correctness ===\n");
        double l2d = 0, l2r = 0, maxd = 0;
        for (int64_t i = 0; i < V; i++) {
            const float a = bf16_to_float(batch_last[i]);
            const float b = bf16_to_float(seq_last[i]);
            const float d = a - b;
            l2d += d * d;
            l2r += b * b;
            if (std::abs(d) > maxd) maxd = std::abs(d);
        }
        // Relative L2 error; the epsilon guards against an all-zero reference.
        const double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);

        float seq_best = 0.0f, batch_best = 0.0f;
        const int seq_arg = argmax_bf16(seq_last, &seq_best);
        const int batch_arg = argmax_bf16(batch_last, &batch_best);

        printf("Last-position logits:\n");
        printf("  seq[3]  argmax = %d (%.3f)\n", seq_arg, seq_best);
        printf("  batch   argmax = %d (%.3f)\n", batch_arg, batch_best);
        printf("  rel=%.4e  max=%.4f\n", rel, maxd);

        const bool logits_match = rel < 5e-2;
        printf("  %s\n", logits_match ? "PASS" : "FAIL (batch forward diverges from sequential)");
        if (past_after_batch != past_after_seq) {
            printf("  FAIL (cache length mismatch: sequential=%lld batch=%lld)\n",
                   static_cast<long long>(past_after_seq), static_cast<long long>(past_after_batch));
        }
        ok = ok && logits_match;

        printf("\nNote: current Runner.prefill() returns ONLY last-position logits. For PLD\n");
        printf("we need all-position logits: requires extending prefill to optionally output\n");
        printf("[S, V] logits tensor.\n");
    }
    // Propagate failure through the exit status so scripts and CI can detect it.
    return ok ? 0 : 1;
}