Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto") - Notebooks
- Google Colab
- Kaggle
Wave 21c: verify PRIME-RL adapter parity against upstream source (byte-for-byte)
Browse filesThe recipe's only automated parity check (test_parity_with_prime_rl_default_loss_fn)
is skip-marked whenever prime-rl isn't importable in the framework venv — i.e.
essentially always, since we deliberately keep prime-rl's heavy deps (vLLM,
pydantic config trees, flash-attn) out of our test env. The fallback was an
in-file reimplementation (_reference_default_loss), so a shared bug could pass
silently. This closes that gap against the ACTUAL upstream source.
## What was done
Built an out-of-band parity harness that:
- clones PrimeIntellect-ai/prime-rl (shallow)
- builds an isolated venv with ONLY torch+beartype+jaxtyping+numpy
- loads upstream src/prime_rl/trainer/rl/loss.py directly by path, stubbing the
two modules it imports (prime_rl.configs.trainer, prime_rl.utils.utils) so we
skip the vLLM/pydantic dependency tree entirely
- runs identical inputs through our loss_fn and upstream default_loss_fn
## Result
24/24 cases match (12 seeds × 2 regimes: tiny-perturbation + wide-divergence,
the latter exercising both DPPO masking branches), with partial loss masks.
**Max absolute difference 0.00e+00** — bit-identical, not merely within tolerance.
Upstream rev: f510ef6 (2026-05-28).
## Added
- composer_replication/recipes/prime_rl/verify_parity.sh: one-command reproducible
check (clones + isolated venv + sweep). Exit 0 = parity confirmed.
- _parity_harness.py: the sweep harness it runs.
- PARITY_VERIFIED.md: result + provenance + reproduce instructions.
- composer_loss.py docstring: notes parity is verified and that upstream
refactored the importance-ratio into compute_importance_ratio_and_mismatch_kl
(math unchanged); re-run verify_parity.sh after any upstream bump.
## Tests
Recipe unit tests: 15 passed, 1 skipped (the in-venv parity test still skips by
design — the out-of-band script is its reproducible counterpart).
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PRIME-RL upstream parity — VERIFIED
|
| 2 |
+
|
| 3 |
+
**Status:** PASS ✅ — our adapter's Channel-1 loss matches PrimeIntellect-ai/prime-rl's
|
| 4 |
+
upstream `default_loss_fn` **byte-for-byte** (max absolute difference `0.00e+00`).
|
| 5 |
+
|
| 6 |
+
## What was verified
|
| 7 |
+
|
| 8 |
+
`composer_replication/recipes/prime_rl/composer_loss.py::loss_fn` (Channel 1:
|
| 9 |
+
DPPO + KL on the importance-sampling ratio, with the advantage-sign-conditioned
|
| 10 |
+
DPPO mask) produces numerically identical loss to upstream
|
| 11 |
+
`prime_rl.trainer.rl.loss.default_loss_fn` across:
|
| 12 |
+
|
| 13 |
+
- **12 random seeds × 2 regimes** (24 cases total)
|
| 14 |
+
- `tiny_perturb`: inference ≈ trainer + small noise → no DPPO masking (the
|
| 15 |
+
common on-policy regime)
|
| 16 |
+
- `wide_diff`: large trainer/inference divergence → exercises both the
|
| 17 |
+
`dppo_invalid_mask_high` (positive-advantage) and `dppo_invalid_mask_low`
|
| 18 |
+
(negative-advantage) branches hard
|
| 19 |
+
- partial loss masks (~10% of tokens masked out)
|
| 20 |
+
- PRIME-RL's own default config (`dppo_mask_low=0.2`, `dppo_mask_high=0.2`,
|
| 21 |
+
`adv_tau=1.0`, `kl_tau=1e-3`)
|
| 22 |
+
|
| 23 |
+
Result: **24/24 exact matches**, max abs diff `0.00e+00` (not merely within
|
| 24 |
+
`atol=1e-5` — bit-identical for these inputs).
|
| 25 |
+
|
| 26 |
+
## Provenance
|
| 27 |
+
|
| 28 |
+
- Upstream: `PrimeIntellect-ai/prime-rl` @ `f510ef6` (2026-05-28)
|
| 29 |
+
- Verified by loading upstream `src/prime_rl/trainer/rl/loss.py` directly by path
|
| 30 |
+
in an isolated venv (torch+beartype+jaxtyping+numpy only — no vLLM, no pydantic
|
| 31 |
+
config tree), with `prime_rl.configs.trainer` / `prime_rl.utils.utils` stubbed.
|
| 32 |
+
- Reproduce: `bash composer_replication/recipes/prime_rl/verify_parity.sh`
|
| 33 |
+
|
| 34 |
+
## Why this matters
|
| 35 |
+
|
| 36 |
+
Previously the only automated check was `test_parity_with_prime_rl_default_loss_fn`,
|
| 37 |
+
which is skip-marked whenever prime-rl isn't importable in the framework venv —
|
| 38 |
+
i.e. essentially always, because we deliberately keep prime-rl's heavy deps out of
|
| 39 |
+
our test env. The fallback `_reference_default_loss` in the unit tests is an *in-file
|
| 40 |
+
reimplementation*, so a shared bug between it and `loss_fn` would pass silently.
|
| 41 |
+
This out-of-band check closes that gap against the **actual upstream source**.
|
| 42 |
+
|
| 43 |
+
## Note on upstream drift
|
| 44 |
+
|
| 45 |
+
Upstream refactored the importance-ratio computation into a helper
|
| 46 |
+
(`compute_importance_ratio_and_mismatch_kl`) since the line-references in
|
| 47 |
+
`composer_loss.py`'s docstring were written. The **math is unchanged** — the helper
|
| 48 |
+
just extracts `log_importance_ratio / importance_ratio / mismatch_kl`. Our adapter
|
| 49 |
+
remains exact against current `f510ef6`. Re-run `verify_parity.sh` after any
|
| 50 |
+
upstream bump to catch a real divergence early.
|
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Isolated PRIME-RL parity harness — runs OUR adapter vs UPSTREAM default_loss_fn
|
| 2 |
+
byte-for-byte, without installing the full prime-rl package (which drags vLLM,
|
| 3 |
+
pydantic config trees, etc.).
|
| 4 |
+
|
| 5 |
+
Strategy: stub the two modules upstream loss.py imports (`prime_rl.configs.trainer`
|
| 6 |
+
for DefaultLossConfig + CustomLossConfig + LossConfig, and `prime_rl.utils.utils`
|
| 7 |
+
for import_object), then load loss.py by file path. Compare on random inputs.
|
| 8 |
+
|
| 9 |
+
Run with the throwaway venv that has torch+beartype+jaxtyping+numpy:
|
| 10 |
+
/tmp/prime-parity-venv/bin/python this_file.py /path/to/prime-rl /path/to/framework
|
| 11 |
+
"""
|
| 12 |
+
import importlib.util
|
| 13 |
+
import sys
|
| 14 |
+
import types
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
PRIME_RL = Path(sys.argv[1])
|
| 21 |
+
FRAMEWORK = Path(sys.argv[2])
|
| 22 |
+
|
| 23 |
+
# --- Stub the config + utils modules loss.py needs at import time -----------
|
| 24 |
+
cfg_mod = types.ModuleType("prime_rl.configs.trainer")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class DefaultLossConfig:
|
| 29 |
+
# Exact upstream defaults (trainer.py lines 412-425).
|
| 30 |
+
dppo_mask_low: float = 0.2
|
| 31 |
+
dppo_mask_high: float = 0.2
|
| 32 |
+
adv_tau: float = 1.0
|
| 33 |
+
kl_tau: float = 1e-3
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class CustomLossConfig: # only referenced in type hints / isinstance paths
|
| 37 |
+
pass
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class LossConfig:
|
| 41 |
+
pass
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
cfg_mod.DefaultLossConfig = DefaultLossConfig
|
| 45 |
+
cfg_mod.CustomLossConfig = CustomLossConfig
|
| 46 |
+
cfg_mod.LossConfig = LossConfig
|
| 47 |
+
|
| 48 |
+
utils_mod = types.ModuleType("prime_rl.utils.utils")
|
| 49 |
+
utils_mod.import_object = lambda path: None # unused by default_loss_fn
|
| 50 |
+
|
| 51 |
+
# Register stub package tree so `from prime_rl.configs.trainer import ...` resolves.
|
| 52 |
+
for name in ("prime_rl", "prime_rl.configs", "prime_rl.utils"):
|
| 53 |
+
sys.modules.setdefault(name, types.ModuleType(name))
|
| 54 |
+
sys.modules["prime_rl.configs.trainer"] = cfg_mod
|
| 55 |
+
sys.modules["prime_rl.utils.utils"] = utils_mod
|
| 56 |
+
|
| 57 |
+
# --- Load upstream loss.py by path ------------------------------------------
|
| 58 |
+
loss_path = PRIME_RL / "src" / "prime_rl" / "trainer" / "rl" / "loss.py"
|
| 59 |
+
spec = importlib.util.spec_from_file_location("prime_rl.trainer.rl.loss", loss_path)
|
| 60 |
+
upstream = importlib.util.module_from_spec(spec)
|
| 61 |
+
sys.modules["prime_rl.trainer.rl.loss"] = upstream
|
| 62 |
+
spec.loader.exec_module(upstream)
|
| 63 |
+
print(f"loaded upstream loss.py from {loss_path}")
|
| 64 |
+
|
| 65 |
+
# --- Load our adapter -------------------------------------------------------
|
| 66 |
+
sys.path.insert(0, str(FRAMEWORK))
|
| 67 |
+
from composer_replication.recipes.prime_rl.composer_loss import loss_fn as ours # noqa: E402
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@dataclass
|
| 71 |
+
class FakeLossInputs:
|
| 72 |
+
trainer_logprobs: torch.Tensor
|
| 73 |
+
inference_logprobs: torch.Tensor
|
| 74 |
+
teacher_logprobs: object
|
| 75 |
+
advantages: torch.Tensor
|
| 76 |
+
loss_mask: torch.Tensor
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# --- Parity sweep across seeds + regimes ------------------------------------
|
| 80 |
+
cfg = DefaultLossConfig()
|
| 81 |
+
n_pass = 0
|
| 82 |
+
n_total = 0
|
| 83 |
+
max_abs_diff = 0.0
|
| 84 |
+
for seed in range(12):
|
| 85 |
+
for regime in ("tiny_perturb", "wide_diff"):
|
| 86 |
+
g = torch.Generator().manual_seed(seed)
|
| 87 |
+
seq = 32
|
| 88 |
+
trainer_lp = -(0.1 + 2.0 * torch.rand(seq, generator=g)).to(torch.float32)
|
| 89 |
+
if regime == "tiny_perturb":
|
| 90 |
+
inference_lp = (trainer_lp + 0.05 * torch.randn(seq, generator=g)).to(torch.float32)
|
| 91 |
+
else:
|
| 92 |
+
# Large divergence -> exercises the DPPO masking branches hard.
|
| 93 |
+
inference_lp = -(0.1 + 2.0 * torch.rand(seq, generator=g)).to(torch.float32)
|
| 94 |
+
advantages = torch.randn(seq, generator=g, dtype=torch.float32)
|
| 95 |
+
loss_mask = (torch.rand(seq, generator=g) > 0.1) # ~10% masked out
|
| 96 |
+
|
| 97 |
+
up_inputs = upstream.LossInputs(
|
| 98 |
+
trainer_logprobs=trainer_lp,
|
| 99 |
+
inference_logprobs=inference_lp,
|
| 100 |
+
teacher_logprobs=None,
|
| 101 |
+
advantages=advantages,
|
| 102 |
+
loss_mask=loss_mask,
|
| 103 |
+
)
|
| 104 |
+
up_out = upstream.default_loss_fn(up_inputs, cfg)
|
| 105 |
+
|
| 106 |
+
our_out = ours(
|
| 107 |
+
FakeLossInputs(
|
| 108 |
+
trainer_logprobs=trainer_lp.clone(),
|
| 109 |
+
inference_logprobs=inference_lp.clone(),
|
| 110 |
+
teacher_logprobs=None,
|
| 111 |
+
advantages=advantages.clone(),
|
| 112 |
+
loss_mask=loss_mask.clone(),
|
| 113 |
+
),
|
| 114 |
+
alpha_sdpo=0.0,
|
| 115 |
+
beta_dpo=0.0,
|
| 116 |
+
dppo_mask_high=cfg.dppo_mask_high,
|
| 117 |
+
dppo_mask_low=cfg.dppo_mask_low,
|
| 118 |
+
adv_tau=cfg.adv_tau,
|
| 119 |
+
kl_tau=cfg.kl_tau,
|
| 120 |
+
)
|
| 121 |
+
our_loss = our_out.loss if hasattr(our_out, "loss") else our_out
|
| 122 |
+
diff = abs(float(our_loss) - float(up_out.loss))
|
| 123 |
+
max_abs_diff = max(max_abs_diff, diff)
|
| 124 |
+
ok = torch.isclose(our_loss, up_out.loss, atol=1e-5, rtol=1e-5).item()
|
| 125 |
+
n_total += 1
|
| 126 |
+
n_pass += int(ok)
|
| 127 |
+
if not ok:
|
| 128 |
+
print(f" MISMATCH seed={seed} {regime}: ours={float(our_loss):.6f} up={float(up_out.loss):.6f} diff={diff:.2e}")
|
| 129 |
+
|
| 130 |
+
print(f"\nPARITY: {n_pass}/{n_total} cases match upstream (max abs diff {max_abs_diff:.2e})")
|
| 131 |
+
print("RESULT:", "PASS ✅" if n_pass == n_total else "FAIL ❌")
|
| 132 |
+
sys.exit(0 if n_pass == n_total else 1)
|
|
@@ -85,6 +85,14 @@ divides by ``loss_scale``); we mirror that.
|
|
| 85 |
|
| 86 |
License: MIT (matches the rest of the framework). PRIME-RL is Apache-2;
|
| 87 |
we reference its algorithm and convention but vendor no code.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
"""
|
| 89 |
from __future__ import annotations
|
| 90 |
|
|
|
|
| 85 |
|
| 86 |
License: MIT (matches the rest of the framework). PRIME-RL is Apache-2;
|
| 87 |
we reference its algorithm and convention but vendor no code.
|
| 88 |
+
|
| 89 |
+
Upstream parity: VERIFIED byte-for-byte (max abs diff 0.00e+00) against
|
| 90 |
+
PrimeIntellect-ai/prime-rl @ f510ef6 across 24 cases. See
|
| 91 |
+
``PARITY_VERIFIED.md`` and reproduce with ``verify_parity.sh`` (isolated venv,
|
| 92 |
+
no vLLM/pydantic deps). Upstream has since refactored the importance-ratio into
|
| 93 |
+
``compute_importance_ratio_and_mismatch_kl`` — the line-references above predate
|
| 94 |
+
that extraction but the math is unchanged; re-run verify_parity.sh after any
|
| 95 |
+
upstream bump.
|
| 96 |
"""
|
| 97 |
from __future__ import annotations
|
| 98 |
|
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Verify our PRIME-RL composer-loss adapter matches UPSTREAM default_loss_fn
|
| 3 |
+
# byte-for-byte, WITHOUT installing the full prime-rl package (which pulls vLLM,
|
| 4 |
+
# pydantic config trees, flash-attn, etc.). We clone prime-rl, build a throwaway
|
| 5 |
+
# venv with only torch+beartype+jaxtyping+numpy, load upstream loss.py by path
|
| 6 |
+
# with stubbed config/utils modules, and run identical inputs through both.
|
| 7 |
+
#
|
| 8 |
+
# Usage:
|
| 9 |
+
# bash composer_replication/recipes/prime_rl/verify_parity.sh
|
| 10 |
+
#
|
| 11 |
+
# Exit 0 = byte-for-byte parity confirmed; non-zero = mismatch or setup failure.
|
| 12 |
+
#
|
| 13 |
+
# This is the reproducible counterpart to the skip-marked
|
| 14 |
+
# test_parity_with_prime_rl_default_loss_fn unit test: that test only runs when
|
| 15 |
+
# prime-rl is importable in the framework venv (it usually isn't, by design —
|
| 16 |
+
# we don't want prime-rl's heavy deps in our test env). This script provides the
|
| 17 |
+
# real upstream check out-of-band.
|
| 18 |
+
set -euo pipefail
|
| 19 |
+
|
| 20 |
+
PRIME_RL_REPO="${PRIME_RL_REPO:-https://github.com/PrimeIntellect-ai/prime-rl.git}"
|
| 21 |
+
WORK="${WORK:-/tmp/prime-rl-parity-check}"
|
| 22 |
+
FRAMEWORK="$(cd "$(dirname "${BASH_SOURCE[0]}")/../../.." && pwd)"
|
| 23 |
+
CLONE="$WORK/prime-rl"
|
| 24 |
+
VENV="$WORK/venv"
|
| 25 |
+
HARNESS="$WORK/harness.py"
|
| 26 |
+
|
| 27 |
+
mkdir -p "$WORK"
|
| 28 |
+
|
| 29 |
+
echo "==> Cloning prime-rl (shallow) into $CLONE"
|
| 30 |
+
if [ ! -d "$CLONE/.git" ]; then
|
| 31 |
+
git clone --depth 1 "$PRIME_RL_REPO" "$CLONE"
|
| 32 |
+
fi
|
| 33 |
+
PRIME_REV="$(cd "$CLONE" && git rev-parse --short HEAD)"
|
| 34 |
+
echo " upstream rev: $PRIME_REV"
|
| 35 |
+
|
| 36 |
+
echo "==> Building isolated venv (torch+beartype+jaxtyping+numpy only)"
|
| 37 |
+
if [ ! -x "$VENV/bin/python" ]; then
|
| 38 |
+
python3 -m venv "$VENV"
|
| 39 |
+
"$VENV/bin/pip" install --quiet --upgrade pip
|
| 40 |
+
# CPU torch is plenty for a loss-numerics parity check.
|
| 41 |
+
"$VENV/bin/pip" install --quiet torch --index-url https://download.pytorch.org/whl/cpu
|
| 42 |
+
"$VENV/bin/pip" install --quiet beartype jaxtyping numpy
|
| 43 |
+
fi
|
| 44 |
+
|
| 45 |
+
echo "==> Writing parity harness"
|
| 46 |
+
cp "$FRAMEWORK/composer_replication/recipes/prime_rl/_parity_harness.py" "$HARNESS"
|
| 47 |
+
|
| 48 |
+
echo "==> Running parity sweep"
|
| 49 |
+
"$VENV/bin/python" "$HARNESS" "$CLONE" "$FRAMEWORK"
|