Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """Upstream-parity test for TAID. | |
| This test compares our `taid_loss` against the reference implementation | |
| in ``SakanaAI/TAID/src/distil_losses/taid.py`` + ``fkl.py``. The upstream | |
| clone is expected at ``/tmp/taid-clone``; if absent, the test is skipped. | |
| To run:: | |
| git clone --depth 1 https://github.com/SakanaAI/TAID /tmp/taid-clone | |
| pytest composer_replication/distillation/tests/test_taid_parity.py | |
| Parity is asserted at atol/rtol = 1e-5 over a small batch on CPU. | |
| """ | |
| from __future__ import annotations | |
| import importlib.util | |
| import os | |
| import sys | |
| import pytest | |
| import torch | |
| from composer_replication.distillation import taid_loss | |
| UPSTREAM_PATH = os.environ.get("TAID_UPSTREAM_PATH", "/tmp/taid-clone") | |
| UPSTREAM_TAID_PY = os.path.join(UPSTREAM_PATH, "src", "distil_losses", "taid.py") | |
| UPSTREAM_FKL_PY = os.path.join(UPSTREAM_PATH, "src", "distil_losses", "fkl.py") | |
def _load_upstream_forward_kl():
    """Load only the ``forward_kl`` function from the upstream TAID clone.

    The upstream ``fkl.py`` is not imported as a module because it depends
    on `lightning` and on a relative ``.base`` import. Instead the file is
    read as text, the span from ``def forward_kl(`` up to the next top-level
    ``class`` statement is cut out, and only that span is exec'd in a
    namespace pre-populated with the imports it needs.

    Returns:
        The upstream ``forward_kl`` callable, or ``None`` when either
        upstream file is missing (there is no clone to compare against).

    Raises:
        RuntimeError: If ``fkl.py`` exists but contains no ``forward_kl``
            definition, e.g. because the upstream layout changed.
    """
    if not (os.path.isfile(UPSTREAM_TAID_PY) and os.path.isfile(UPSTREAM_FKL_PY)):
        return None

    # Context manager so the handle is closed deterministically; explicit
    # encoding so the read does not depend on the platform locale.
    with open(UPSTREAM_FKL_PY, encoding="utf-8") as fkl_file:
        src = fkl_file.read()

    # Keep only the text from the function header up to the next top-level
    # class. Everything before the header -- including the module-level
    # `from .base import DistilLoss` -- is dropped by this slice.
    marker = "def forward_kl("
    _, found, after_marker = src.partition(marker)
    if not found:
        raise RuntimeError(
            f"`forward_kl` definition not found in {UPSTREAM_FKL_PY}; "
            "the upstream file layout may have changed"
        )
    fwd_kl_src = marker + after_marker.split("\nclass ", 1)[0]

    # Build a minimal namespace that mimics the upstream imports.
    sandbox: dict = {}
    exec(
        "from typing import Optional\n"
        "import torch\n"
        "from torch.nn import functional as F\n",
        sandbox,
    )
    # NOTE(review): this executes code read from the clone directory; only
    # point TAID_UPSTREAM_PATH at a checkout you trust.
    exec(fwd_kl_src, sandbox)
    return sandbox["forward_kl"]
def _upstream_compute_loss(student_logits, teacher_logits, mask, t):
    """Inline equivalent of upstream `TAID.compute_loss` (taid.py:66-80).

    Performs the same arithmetic as upstream, minus the LightningModule
    bookkeeping that surrounds it there. Returns ``None`` when the upstream
    ``forward_kl`` could not be loaded.
    """
    upstream_fkl = _load_upstream_forward_kl()
    if upstream_fkl is None:
        return None

    import torch.nn.functional as F

    # Interpolate in logit space between the (detached) student and the
    # teacher, then normalise in float32 to get the intermediate target.
    mixed_logits = (1 - t) * student_logits.detach() + t * teacher_logits
    target_probs = F.softmax(mixed_logits, dim=-1, dtype=torch.float32)

    return upstream_fkl(
        logits=student_logits,
        teacher_logits=teacher_logits,
        mask=mask,
        teacher_probs=target_probs,
    )
# `t` needs a value source: no fixture or parametrization for it is visible
# in this file, so pytest would otherwise fail with "fixture 't' not found".
# The values cover both endpoints and the interior of [0, 1].
@pytest.mark.parametrize("t", [0.0, 0.1, 0.5, 0.9, 1.0])
def test_taid_parity_against_upstream(t):
    """Our taid_loss matches upstream TAID.compute_loss(...) within atol=1e-5.

    Runs once per parametrized ``t`` on a fixed-seed batch with random
    logits and a non-trivial mask. Skipped (not failed) when the upstream
    clone is not available, as the module docstring promises.
    """
    torch.manual_seed(0)
    B, T, V = 2, 4, 16
    student = torch.randn(B, T, V, requires_grad=True)
    teacher = torch.randn(B, T, V)
    # Ragged mask: 3 valid positions in the first row, 2 in the second.
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float32)

    ours = taid_loss(student, teacher, mask, t=t)
    theirs = _upstream_compute_loss(student, teacher, mask, t=t)
    if theirs is None:
        # The helper returns None only when the upstream files are absent.
        pytest.skip(f"upstream TAID clone not found at {UPSTREAM_PATH}")

    torch.testing.assert_close(ours, theirs, atol=1e-5, rtol=1e-5)
def test_taid_parity_with_full_mask():
    """Sanity: full-mask path also matches upstream.

    Uses an all-ones mask at a single ``t=0.4``. Skipped (not failed) when
    the upstream clone is not available, as the module docstring promises.
    """
    torch.manual_seed(1)
    B, T, V = 1, 3, 8
    student = torch.randn(B, T, V, requires_grad=True)
    teacher = torch.randn(B, T, V)
    mask = torch.ones(B, T)

    ours = taid_loss(student, teacher, mask, t=0.4)
    theirs = _upstream_compute_loss(student, teacher, mask, t=0.4)
    if theirs is None:
        # The helper returns None only when the upstream files are absent.
        pytest.skip(f"upstream TAID clone not found at {UPSTREAM_PATH}")

    torch.testing.assert_close(ours, theirs, atol=1e-5, rtol=1e-5)