Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 6,102 Bytes
e5add15 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | """Numerical parity test against the upstream OPSD reference.
Loads `OPSDTrainer.generalized_jsd_loss` from a clone of siyan-zhao/OPSD at
/tmp/opsd-clone (override with $OPSD_CLONE) and asserts that our
re-implementation in `composer_replication.opsd` matches it to within
atol = rtol = 1e-5 (float64 inputs) across a grid of shapes and β values,
plus a temperature / top-k spot check. Skips cleanly when the upstream
clone is absent.
Why this lives in `tests/` rather than docs: numerical parity is the
contract for this port of the upstream loss. If a future refactor of
`generalized_jsd_loss` silently shifts the loss values (and with them the
gradients) again, this test fails immediately.
"""
from __future__ import annotations
import importlib.util
import os
import sys
from pathlib import Path
import pytest
import torch
from composer_replication.opsd import generalized_jsd_loss
# ----------------------------------------------------------------------
# Locate upstream OPSDTrainer.generalized_jsd_loss
# ----------------------------------------------------------------------
# Root of the local siyan-zhao/OPSD checkout. An unset *or empty*
# $OPSD_CLONE falls back to the default: an empty value would otherwise
# become Path(""), i.e. the current working directory. A leading `~` is
# expanded so `OPSD_CLONE='~/opsd'` works even when the shell did not
# expand it.
_OPSD_CLONE = Path(os.environ.get("OPSD_CLONE") or "/tmp/opsd-clone").expanduser()
# The single upstream file the parity tests read.
_OPSD_TRAINER_PATH = _OPSD_CLONE / "opsd_trainer.py"
def _load_upstream():
    """Return upstream ``generalized_jsd_loss`` as a plain function, or None.

    The upstream ``opsd_trainer.py`` imports heavyweight TRL / transformers
    machinery at module scope, which we do not want to drag into the test
    process. So instead of importing it, we parse the file with :mod:`ast`
    (parsing executes nothing), pick out the ``generalized_jsd_loss``
    definition, and compile + exec only that one function. The method is
    assumed to need nothing beyond ``torch`` and ``torch.nn.functional``
    (bound as ``F``); any other free name would raise NameError when called.

    Lookup order: the method defined directly in ``class OPSDTrainer``;
    failing that, the first ``def generalized_jsd_loss`` anywhere in the
    file, so a refactor that moves it to module scope is still found.

    Returns:
        The extracted function, or ``None`` when the upstream file does not
        exist or contains no ``generalized_jsd_loss`` definition.

    Raises:
        SyntaxError: if the upstream file exists but the running interpreter
            cannot parse it. This is deliberately loud rather than a skip.
    """
    import ast

    import torch.nn.functional as F

    if not _OPSD_TRAINER_PATH.exists():
        return None

    # Python source files are UTF-8 by default (PEP 3120); don't depend on
    # the platform's locale encoding.
    source = _OPSD_TRAINER_PATH.read_text(encoding="utf-8")
    tree = ast.parse(source, filename=str(_OPSD_TRAINER_PATH))

    fn_name = "generalized_jsd_loss"
    in_class = (
        item
        for cls in tree.body
        if isinstance(cls, ast.ClassDef) and cls.name == "OPSDTrainer"
        for item in cls.body
        if isinstance(item, ast.FunctionDef) and item.name == fn_name
    )
    anywhere = (
        node
        for node in ast.walk(tree)
        if isinstance(node, ast.FunctionDef) and node.name == fn_name
    )
    fn_def = next(in_class, None)
    if fn_def is None:
        fn_def = next(anywhere, None)
    if fn_def is None:
        return None

    # Drop decorators (e.g. @staticmethod) so exec binds a plain function.
    # Unlike text slicing, this is independent of indentation width and of
    # whatever follows the method in the file.
    fn_def.decorator_list = []
    module = ast.Module(body=[fn_def], type_ignores=[])

    # Exec into a fresh namespace with torch + F available. The node keeps
    # its original line numbers, so tracebacks point into the upstream file.
    # NOTE: this executes code from whatever file $OPSD_CLONE points at;
    # only point it at a checkout you trust.
    namespace: dict = {"torch": torch, "F": F}
    exec(compile(module, str(_OPSD_TRAINER_PATH), "exec"), namespace)
    return namespace.get(fn_name)
# Resolved once at import; the parity tests skip when this is None.
_UPSTREAM_FN = _load_upstream()
# Only reported when _UPSTREAM_FN is None (the tests' skipif condition), so
# it should say *why* extraction produced nothing.
if _OPSD_TRAINER_PATH.exists():
    # The clone is present but no usable `generalized_jsd_loss` came out of
    # it (e.g. the upstream method was renamed or moved). Say that instead of
    # wrongly claiming the clone is missing.
    _SKIP_REASON = (
        f"found {_OPSD_TRAINER_PATH} but could not extract "
        "`generalized_jsd_loss` from it (did the upstream layout change?)"
    )
else:
    _SKIP_REASON = (
        f"upstream OPSD clone not found at {_OPSD_TRAINER_PATH} "
        f"(set $OPSD_CLONE or `git clone --depth 1 https://github.com/siyan-zhao/OPSD {_OPSD_CLONE}`)"
    )
# ----------------------------------------------------------------------
# Parity grid
# ----------------------------------------------------------------------
# Each entry is a logits shape, unpacked by the tests as (B, T, V); the
# labels mask in the masked test is drawn with shape (B, T).
_SHAPES = [(1, 4, 16), (2, 8, 32), (3, 5, 64), (1, 16, 8), (4, 3, 24)]
# β values swept by the parametrized tests: both endpoints and the midpoint.
_BETAS = [0.0, 0.5, 1.0]
@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
@pytest.mark.parametrize("shape", _SHAPES)
@pytest.mark.parametrize("beta", _BETAS)
def test_parity_unmasked(shape, beta):
    """Without a labels mask, our loss must equal upstream's (atol = rtol = 1e-5)."""
    batch, steps, vocab = shape
    # The seed is a fixed function of the shape, so each grid cell is reproducible.
    rng = torch.Generator().manual_seed(13 + batch * 31 + steps * 17 + vocab)
    draw = dict(generator=rng, dtype=torch.float64)
    # Draw order matters: student first, then teacher, from the same stream.
    student = torch.randn(batch, steps, vocab, **draw)
    teacher = torch.randn(batch, steps, vocab, **draw)

    ours = generalized_jsd_loss(student, teacher, beta=beta)
    theirs = _UPSTREAM_FN(student, teacher, beta=beta)  # type: ignore[misc]
    close = torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5)
    assert close, (
        f"mismatch at shape={shape} beta={beta}: "
        f"ours={ours.item()} theirs={theirs.item()}"
    )
@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
@pytest.mark.parametrize("shape", _SHAPES)
@pytest.mark.parametrize("beta", _BETAS)
def test_parity_masked(shape, beta):
    """Same parity, but with a labels mask that ignores ~half the tokens.

    Labels equal to -100 mark ignored positions; any other value marks a
    valid one. At least one position is always kept valid. If every token
    were ignored, the loss would presumably reduce over an empty set (NaN for
    a mean-style reduction; an assumption about both implementations). NaN
    never compares equal under ``torch.allclose``'s default
    ``equal_nan=False``, so such a draw would fail without saying anything
    about parity.
    """
    B, T, V = shape
    g = torch.Generator().manual_seed(101 + B * 7 + T * 11 + V)
    student = torch.randn(B, T, V, generator=g, dtype=torch.float64)
    teacher = torch.randn(B, T, V, generator=g, dtype=torch.float64)
    # Random valid/ignored mask: -100 for ignored, anything else for valid.
    labels = torch.randint(0, 2, (B, T), generator=g)
    labels = torch.where(labels == 0, torch.full_like(labels, -100), labels)
    if (labels == -100).all():
        # Only the (seed-dependent) all-ignored draw is touched, so every
        # other grid cell keeps exactly the same inputs as before.
        labels[0, 0] = 1
    ours = generalized_jsd_loss(student, teacher, labels=labels, beta=beta)
    theirs = _UPSTREAM_FN(student, teacher, labels=labels, beta=beta)  # type: ignore[misc]
    assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), (
        f"mismatch at shape={shape} beta={beta}: ours={ours.item()} theirs={theirs.item()}"
    )
@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
def test_parity_temperature_and_topk():
    """Spot-check parity with temperature=2.0 and top_k=8 across several betas."""
    rng = torch.Generator().manual_seed(42)
    logits_shape = (2, 6, 32)
    student = torch.randn(*logits_shape, generator=rng, dtype=torch.float64)
    teacher = torch.randn(*logits_shape, generator=rng, dtype=torch.float64)
    # Identical non-default knobs go to both implementations.
    knobs = {"temperature": 2.0, "top_k": 8}

    for beta in (0.0, 0.3, 0.5, 0.7, 1.0):
        ours = generalized_jsd_loss(student, teacher, beta=beta, **knobs)
        theirs = _UPSTREAM_FN(student, teacher, beta=beta, **knobs)  # type: ignore[misc]
        assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), (
            f"temp+topk parity failed at beta={beta}: "
            f"ours={ours.item()} theirs={theirs.item()}"
        )
|