File size: 2,930 Bytes
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d88715c
 
 
 
 
 
 
 
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Composer Replication Framework — quickstart smoke.

Runs the same 5-step CPU smoke test as Spike 006, but imports from the
installed package API instead of from the spike directory.

Usage:
    cd composer-replication-framework
    pip install -e .
    python examples/qwen_05b_quickstart/run.py

Expected: total loss decreases from ~0.7 to <0.01 over 5 backward steps;
all gradients finite; ~3-5 min wall-clock on CPU; ~1 GB of disk for the
Qwen2.5-0.5B weights (downloaded once into the HF cache). The script exits
with status 0 on PASS and 1 on FAIL.
"""
from __future__ import annotations

import sys

import torch

# After `pip install -e .` from repo root, this import resolves cleanly.
from composer_replication import build_batch, compose_loss


# Hugging Face Hub repo id of the model exercised by this smoke test; its
# weights are downloaded into the local HF cache on first use.
MODEL_REPO: str = "Qwen/Qwen2.5-0.5B-Instruct"


def main(n_steps: int = 5) -> int:
    """Run a short CPU fine-tuning smoke test and report PASS/FAIL.

    Loads ``MODEL_REPO`` (tokenizer + causal LM, fp32, CPU). It builds one
    chat-template batch with ``build_batch`` and runs ``n_steps`` AdamW
    steps on the composite loss returned by ``compose_loss``, reusing the
    same batch every step. Each step prints the loss components and whether
    every populated gradient was finite.

    Args:
        n_steps: Number of backward/optimizer steps. Defaults to 5, the
            step count described in the module docstring.

    Returns:
        Process exit code. It is 0 (PASS) when two conditions hold: the last
        recorded total loss is strictly below the first, and every step
        produced only finite gradients. Otherwise it is 1 (FAIL).

    Raises:
        ValueError: If ``n_steps`` is less than 1 (there would be no losses
            to compare).
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    print(f"[quickstart] loading {MODEL_REPO} (CPU, fp32) ...")
    import random

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Pin the Python and torch RNGs before anything that might draw random
    # numbers, so repeated runs print the same per-step numbers. The
    # pretrained weights themselves are not random. NOTE(review): which
    # consumers actually use randomness (build_batch, compose_loss,
    # train-mode dropout if enabled in the model config) is not visible
    # here.
    random.seed(42)
    torch.manual_seed(42)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=torch.float32)
    model = model.to("cpu")
    model.train()
    n_params_b = sum(p.numel() for p in model.parameters()) / 1e9
    print(f"[quickstart] loaded — {n_params_b:.3f}B params")

    print("[quickstart] building real chat-template batch ...")
    batch = build_batch(tokenizer, device="cpu")

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    print(f"[quickstart] running {n_steps} backward steps ...")
    losses: list[float] = []
    # Tracks finiteness across ALL steps; it feeds the final verdict so a
    # NaN/Inf gradient can never be reported as PASS.
    all_finite = True
    for step in range(n_steps):
        optimizer.zero_grad()
        components = compose_loss(model, batch, alpha_sdpo=0.1, beta_replay=0.05)
        components.total.backward()

        # Inspect the gradients optimizer.step() is about to apply.
        # Parameters that received no gradient (grad is None) are skipped.
        finite = all(
            p.grad is None or bool(torch.isfinite(p.grad).all())
            for p in model.parameters()
        )
        all_finite = all_finite and finite

        optimizer.step()

        c = components.detached()
        losses.append(c["total"])
        print(
            f"  step {step}: total={c['total']:.4f}  "
            f"lm_ce={c['lm_ce']:.4f}  "
            f"sdpo={c['sdpo_jsd']:.4f}  "
            f"dpo={c['trace_replay_dpo']:.4f}  "
            f"finite={finite}"
        )

    # Each recorded loss is the value *before* that step's update.
    initial, final = losses[0], losses[-1]
    decreased = final < initial
    passed = decreased and all_finite
    # Guard against ZeroDivisionError when the initial loss is exactly 0.
    reduction = f"{(1 - final / initial) * 100:.1f}%" if initial != 0 else "n/a"

    print()
    print("=" * 56)
    print(f"  Initial loss: {initial:.4f}")
    print(f"  Final loss:   {final:.4f}")
    print(f"  Reduction:    {reduction}")
    print(f"  Grads finite: {all_finite}")
    print(f"  Verdict:      {'PASS' if passed else 'FAIL'}")
    print("=" * 56)

    return 0 if passed else 1


if __name__ == "__main__":
    # Hand main()'s return value (0 = PASS, 1 = FAIL) to the interpreter as
    # the process exit status; this is exactly what sys.exit() does.
    raise SystemExit(main())