"""Composer Replication Framework — quickstart smoke.

Runs the same 5-step CPU smoke as Spike 006, but using the installed package
API instead of importing from the spike directory.

Usage:
    cd composer-replication-framework
    pip install -e .
    python examples/qwen_05b_quickstart/run.py

Expected:
    - loss decreases from ~0.7 to <0.01 over 5 backward steps;
    - all gradients finite;
    - ~3-5 min wall-clock on CPU;
    - ~1 GB disk for Qwen2.5-0.5B weights (downloaded once into HF cache).

Exit status is 0 (PASS) only when the final loss is below the initial loss
AND every step produced finite gradients; otherwise 1 (FAIL).
"""

from __future__ import annotations

import random
import sys

import torch

# After `pip install -e .` from repo root, this import resolves cleanly.
from composer_replication import build_batch, compose_loss

MODEL_REPO = "Qwen/Qwen2.5-0.5B-Instruct"
N_STEPS = 5  # number of optimizer steps in the smoke run
SEED = 42  # RNG seed for Python's `random` and torch


def _grads_finite(model: torch.nn.Module) -> bool:
    """Return True if every populated ``.grad`` on *model* contains no NaN/Inf.

    Parameters whose ``.grad`` is None are skipped.
    """
    return all(
        p.grad is None or bool(torch.isfinite(p.grad).all().item())
        for p in model.parameters()
    )


def main() -> int:
    """Run the CPU smoke test and return a process exit code.

    Loads ``MODEL_REPO`` in fp32 on CPU and builds one batch with
    ``build_batch``. It then takes ``N_STEPS`` AdamW steps on the loss from
    ``compose_loss`` and prints per-step loss components plus a gradient
    finiteness flag.

    Returns:
        0 if the final loss is lower than the initial loss and all steps had
        finite gradients, else 1.
    """
    print(f"[quickstart] loading {MODEL_REPO} (CPU, fp32) ...")
    # Deferred import: transformers is heavy, and this keeps the
    # "loading" message ahead of the import cost.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Pin RNG state so the per-step numbers printed below are reproducible
    # across runs. NOTE(review): presumably build_batch / compose_loss consume
    # randomness (e.g. dummy reference logprobs for the DPO channel) — confirm.
    random.seed(SEED)
    torch.manual_seed(SEED)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=torch.float32)
    model = model.to("cpu")
    model.train()
    n_params_b = sum(p.numel() for p in model.parameters()) / 1e9
    print(f"[quickstart] loaded — {n_params_b:.3f}B params")

    print("[quickstart] building real chat-template batch ...")
    batch = build_batch(tokenizer, device="cpu")

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    print(f"[quickstart] running {N_STEPS} backward steps ...")
    losses: list[float] = []
    all_finite = True
    for step in range(N_STEPS):
        optimizer.zero_grad()
        components = compose_loss(model, batch, alpha_sdpo=0.1, beta_replay=0.05)
        components.total.backward()
        # Check gradients before stepping, so the flag reflects this step's
        # backward pass.
        finite = _grads_finite(model)
        all_finite = all_finite and finite
        optimizer.step()
        c = components.detached()
        losses.append(c["total"])
        print(
            f"  step {step}: total={c['total']:.4f}  "
            f"lm_ce={c['lm_ce']:.4f}  "
            f"sdpo={c['sdpo_jsd']:.4f}  "
            f"dpo={c['trace_replay_dpo']:.4f}  "
            f"finite={finite}"
        )

    initial, final = losses[0], losses[-1]
    decreased = final < initial
    passed = decreased and all_finite
    # Guard the relative-reduction display against a zero initial loss.
    reduction = "n/a" if initial == 0 else f"{(1 - final / initial) * 100:.1f}%"
    print()
    print("=" * 56)
    print(f"  Initial loss:     {initial:.4f}")
    print(f"  Final loss:       {final:.4f}")
    print(f"  Reduction:        {reduction}")
    print(f"  All grads finite: {all_finite}")
    print(f"  Verdict:          {'PASS' if passed else 'FAIL'}")
    print("=" * 56)
    return 0 if passed else 1


if __name__ == "__main__":
    sys.exit(main())