Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """Numerical parity test against the upstream OPSD reference. | |
| Loads `OPSDTrainer.generalized_jsd_loss` from a clone of siyan-zhao/OPSD at | |
| /tmp/opsd-clone (override with $OPSD_CLONE) and asserts our re-implementation | |
| in `composer_replication.opsd` matches it within 1e-5 (atol and rtol) across a grid of | |
| shapes and β values. Skips cleanly when the upstream clone is absent. | |
| Why this lives in `tests/` rather than docs: numerical parity is the | |
| contract for this lift. If a future refactor of `generalized_jsd_loss` | |
| silently shifts gradients again, this test fails immediately. | |
| """ | |
| from __future__ import annotations | |
| import importlib.util | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import pytest | |
| import torch | |
| from composer_replication.opsd import generalized_jsd_loss | |
# ----------------------------------------------------------------------
# Where the upstream OPSD checkout lives
# ----------------------------------------------------------------------
# The checkout root comes from $OPSD_CLONE when set, otherwise /tmp/opsd-clone.
_OPSD_CLONE = Path(os.getenv("OPSD_CLONE", "/tmp/opsd-clone"))
# The trainer module is expected directly under the checkout root.
_OPSD_TRAINER_PATH = _OPSD_CLONE.joinpath("opsd_trainer.py")
def _load_upstream():
    """Extract upstream ``generalized_jsd_loss`` from a local OPSD clone.

    The upstream ``opsd_trainer.py`` imports heavyweight TRL / transformers
    machinery at module scope, which we do not want to drag into the test
    process. Instead of importing the module, the source text of the single
    method is sliced out, dedented, and exec'd in a namespace that only
    provides ``torch`` and ``torch.nn.functional`` (as ``F``).

    Returns:
        The exec'd ``generalized_jsd_loss`` callable, or ``None`` when the
        trainer file does not exist, the ``def generalized_jsd_loss(`` header
        is not present, or no following sibling/class boundary is found.
    """
    if not _OPSD_TRAINER_PATH.exists():
        return None
    # Decode explicitly so the result does not depend on the platform locale.
    text = _OPSD_TRAINER_PATH.read_text(encoding="utf-8")

    # The block starts at the method header (any decorator above it is
    # deliberately left out, so the exec'd result is a plain function).
    header = "def generalized_jsd_loss("
    start = text.find(header)
    if start < 0:
        return None
    rest = text[start:]

    # The block ends at the next sibling of the method: a decorator or `def`
    # at exactly 4 spaces of indent (class-body level), or the next top-level
    # class. Searching begins after the header so it cannot match itself.
    end_marker_offsets = []
    for marker in ("\n    @", "\n    def ", "\nclass "):
        idx = rest.find(marker, len(header))
        if idx > 0:
            end_marker_offsets.append(idx)
    if not end_marker_offsets:
        return None
    fn_text = rest[: min(end_marker_offsets)]

    # Dedent one class level. The slice begins at `def`, so the first line is
    # already at column 0; every following line carries the 4-space class
    # indent, which is stripped. Lines without that prefix are left untouched.
    fn_text = "\n".join(
        line[4:] if line.startswith("    ") else line
        for line in fn_text.splitlines()
    )

    # Exec into a fresh namespace with torch + F available.
    # NOTE(security): this executes source taken from the clone directory;
    # only point $OPSD_CLONE at a checkout you trust.
    import torch.nn.functional as F  # noqa: F401 (used by exec'd code)

    namespace: dict = {"torch": torch, "F": F}
    exec(compile(fn_text, str(_OPSD_TRAINER_PATH), "exec"), namespace)
    return namespace.get("generalized_jsd_loss")
# Resolved once at import time; ``None`` means the reference could not be loaded.
_UPSTREAM_FN = _load_upstream()

# Message shown when the parity tests are skipped.
_SKIP_REASON = (
    f"upstream OPSD clone not found at {_OPSD_TRAINER_PATH} "
    f"(set $OPSD_CLONE or `git clone --depth 1 https://github.com/siyan-zhao/OPSD {_OPSD_CLONE}`)"
)

# ----------------------------------------------------------------------
# Parity grid
# ----------------------------------------------------------------------
# Each entry is unpacked by the tests as (B, T, V).
_SHAPES = [(1, 4, 16), (2, 8, 32), (3, 5, 64), (1, 16, 8), (4, 3, 24)]

# β values passed through to both implementations.
_BETAS = [0.0, 0.5, 1.0]
@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
@pytest.mark.parametrize("shape", _SHAPES)
@pytest.mark.parametrize("beta", _BETAS)
def test_parity_unmasked(shape, beta):
    """Our `generalized_jsd_loss` must match upstream within 1e-5 atol/rtol.

    Runs once per (shape, beta) pair from `_SHAPES` x `_BETAS` and is skipped
    when the upstream reference could not be loaded.

    Args:
        shape: ``(B, T, V)`` sizes of the random input tensors.
        beta: forwarded unchanged as ``beta=`` to both implementations.
    """
    B, T, V = shape
    # Seed derived from the shape so each grid cell gets its own reproducible
    # inputs.
    g = torch.Generator().manual_seed(13 + B * 31 + T * 17 + V)
    # float64 inputs keep rounding noise well below the 1e-5 tolerance.
    student = torch.randn(B, T, V, generator=g, dtype=torch.float64)
    teacher = torch.randn(B, T, V, generator=g, dtype=torch.float64)
    ours = generalized_jsd_loss(student, teacher, beta=beta)
    theirs = _UPSTREAM_FN(student, teacher, beta=beta)  # type: ignore[misc]
    assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), (
        f"mismatch at shape={shape} beta={beta}: ours={ours.item()} theirs={theirs.item()}"
    )
@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
@pytest.mark.parametrize("shape", _SHAPES)
@pytest.mark.parametrize("beta", _BETAS)
def test_parity_masked(shape, beta):
    """Same parity check, with a labels mask that ignores ~half the tokens.

    Runs once per (shape, beta) pair from `_SHAPES` x `_BETAS` and is skipped
    when the upstream reference could not be loaded.

    Args:
        shape: ``(B, T, V)`` sizes of the random input tensors; the labels
            tensor has shape ``(B, T)``.
        beta: forwarded unchanged as ``beta=`` to both implementations.
    """
    B, T, V = shape
    # Seed derived from the shape so each grid cell gets its own reproducible
    # inputs (different constants from the unmasked test).
    g = torch.Generator().manual_seed(101 + B * 7 + T * 11 + V)
    student = torch.randn(B, T, V, generator=g, dtype=torch.float64)
    teacher = torch.randn(B, T, V, generator=g, dtype=torch.float64)
    # Random valid/ignored mask: draw 0/1 per position, then turn every 0
    # into -100 (ignored) and leave 1 as-is (valid).
    labels = torch.randint(0, 2, (B, T), generator=g)
    labels = torch.where(labels == 0, torch.full_like(labels, -100), labels)
    ours = generalized_jsd_loss(student, teacher, labels=labels, beta=beta)
    theirs = _UPSTREAM_FN(student, teacher, labels=labels, beta=beta)  # type: ignore[misc]
    assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), (
        f"mismatch at shape={shape} beta={beta}: ours={ours.item()} theirs={theirs.item()}"
    )
@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
def test_parity_temperature_and_topk():
    """Spot-check the temperature + top_k branches against upstream.

    Uses one fixed pair of (2, 6, 32) float64 inputs and sweeps five β
    values with ``temperature=2.0`` and ``top_k=8``. Skipped when the
    upstream reference could not be loaded.
    """
    g = torch.Generator().manual_seed(42)
    student = torch.randn(2, 6, 32, generator=g, dtype=torch.float64)
    teacher = torch.randn(2, 6, 32, generator=g, dtype=torch.float64)
    # The same inputs are reused for every β; only the mixing weight varies.
    for beta in (0.0, 0.3, 0.5, 0.7, 1.0):
        ours = generalized_jsd_loss(student, teacher, beta=beta, temperature=2.0, top_k=8)
        theirs = _UPSTREAM_FN(  # type: ignore[misc]
            student, teacher, beta=beta, temperature=2.0, top_k=8
        )
        assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), (
            f"temp+topk parity failed at beta={beta}: ours={ours.item()} theirs={theirs.item()}"
        )