Codeseys's picture
Wave 15: 4-angle multi-model self-critique caught 2 math BLOCKERs in primary loss kernels; fixed against upstream byte-for-byte + GSM8K example + ergonomics
e5add15
"""Upstream-parity test for TAID.

This test compares our `taid_loss` against the reference implementation
in ``SakanaAI/TAID/src/distil_losses/taid.py`` + ``fkl.py``. The upstream
clone is expected at ``/tmp/taid-clone`` (set the ``TAID_UPSTREAM_PATH``
environment variable to point at a different location); if absent, the
tests are skipped.

To run::

    git clone --depth 1 https://github.com/SakanaAI/TAID /tmp/taid-clone
    pytest composer_replication/distillation/tests/test_taid_parity.py

Parity is asserted at atol/rtol = 1e-5 over a small batch on CPU.
"""
from __future__ import annotations
import importlib.util
import os
import sys
import pytest
import torch
from composer_replication.distillation import taid_loss
# Root of the upstream SakanaAI/TAID checkout. The TAID_UPSTREAM_PATH
# environment variable takes precedence over the default location.
UPSTREAM_PATH = os.getenv("TAID_UPSTREAM_PATH", "/tmp/taid-clone")

# The two reference source files the parity tests look for in the checkout.
UPSTREAM_TAID_PY, UPSTREAM_FKL_PY = (
    os.path.join(UPSTREAM_PATH, "src", "distil_losses", file_name)
    for file_name in ("taid.py", "fkl.py")
)
def _load_upstream_forward_kl(taid_path=None, fkl_path=None):
    """Load only the ``forward_kl`` function from the upstream clone.

    The upstream module is not imported as a whole because it depends on
    ``lightning`` and on a relative ``.base`` import. Instead, the text of
    ``fkl.py`` from ``def forward_kl(`` up to the first following top-level
    ``class`` statement is executed in a scratch namespace that already
    holds the names the function relies on (``Optional``, ``torch``, ``F``).

    Args:
        taid_path: Path to the upstream ``taid.py``. Defaults to
            ``UPSTREAM_TAID_PY``. Only its existence is checked.
        fkl_path: Path to the upstream ``fkl.py``. Defaults to
            ``UPSTREAM_FKL_PY``.

    Returns:
        The ``forward_kl`` callable, or ``None`` when either file is missing
        or ``fkl.py`` contains no ``def forward_kl(`` definition.
    """
    if taid_path is None:
        taid_path = UPSTREAM_TAID_PY
    if fkl_path is None:
        fkl_path = UPSTREAM_FKL_PY
    if not (os.path.isfile(taid_path) and os.path.isfile(fkl_path)):
        return None

    # Read with an explicit encoding and close the handle deterministically.
    with open(fkl_path, encoding="utf-8") as handle:
        src = handle.read()

    # Keep only the text from the function header up to the next top-level
    # class; this leaves out the module-level `from .base import DistilLoss`
    # and the class that depends on it.
    marker = "def forward_kl("
    _, found, tail = src.partition(marker)
    if not found:
        # Callers treat None as "upstream forward_kl could not be loaded".
        return None
    fwd_kl_src = marker + tail.split("\nclass ", 1)[0]

    # Build a minimal namespace that mimics the upstream imports.
    sandbox: dict = {}
    exec(
        "from typing import Optional\n"
        "import torch\n"
        "from torch.nn import functional as F\n",
        sandbox,
    )
    # NOTE(security): this executes code read from the upstream checkout;
    # only point the paths at a clone you trust.
    exec(fwd_kl_src, sandbox)
    return sandbox["forward_kl"]
def _upstream_compute_loss(student_logits, teacher_logits, mask, t):
    """Inline re-statement of upstream `TAID.compute_loss` (taid.py:66-80).

    The arithmetic is the same as upstream; only the LightningModule
    bookkeeping around it is left out.

    Returns whatever the upstream `forward_kl` returns, or ``None`` when
    `forward_kl` could not be loaded from the clone.
    """
    upstream_fkl = _load_upstream_forward_kl()
    if upstream_fkl is None:
        return None

    # Blend student and teacher in logit space; the student term is detached
    # before blending.
    blended_logits = (1 - t) * student_logits.detach() + t * teacher_logits
    # Normalise the blend in float32 and hand it over as the target
    # distribution.
    target_probs = torch.nn.functional.softmax(
        blended_logits, dim=-1, dtype=torch.float32
    )
    return upstream_fkl(
        logits=student_logits,
        teacher_logits=teacher_logits,
        mask=mask,
        teacher_probs=target_probs,
    )
@pytest.mark.skipif(
    not os.path.isfile(UPSTREAM_TAID_PY),
    reason=(
        f"Upstream TAID clone not found at {UPSTREAM_PATH}. "
        f"Run: git clone --depth 1 https://github.com/SakanaAI/TAID {UPSTREAM_PATH}"
    ),
)
@pytest.mark.parametrize("t", [0.0, 0.1, 0.4, 0.5, 0.9, 1.0])
def test_taid_parity_against_upstream(t):
    """Check `taid_loss` against the upstream computation for several `t`.

    Each parametrised case uses the same fixed-seed random logits and a mask
    that zeroes trailing positions, and requires agreement within
    atol = rtol = 1e-5.
    """
    batch, seq_len, vocab = 2, 4, 16

    # Non-trivial mask: the two rows have different numbers of active
    # positions. Building it draws no random numbers, so doing it before
    # seeding leaves the logits unchanged.
    token_mask = torch.tensor(
        [[1, 1, 1, 0], [1, 1, 0, 0]],
        dtype=torch.float32,
    )

    torch.manual_seed(0)
    student_logits = torch.randn(batch, seq_len, vocab, requires_grad=True)
    teacher_logits = torch.randn(batch, seq_len, vocab)

    actual = taid_loss(student_logits, teacher_logits, token_mask, t=t)
    expected = _upstream_compute_loss(
        student_logits, teacher_logits, token_mask, t=t
    )

    assert expected is not None, "upstream forward_kl could not be loaded"
    torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
@pytest.mark.skipif(
    not os.path.isfile(UPSTREAM_TAID_PY),
    reason=f"Upstream TAID clone not found at {UPSTREAM_PATH}.",
)
def test_taid_parity_with_full_mask():
    """Sanity check: with every position unmasked we still match upstream."""
    interpolation = 0.4
    batch, seq_len, vocab = 1, 3, 8

    torch.manual_seed(1)
    student_logits = torch.randn(batch, seq_len, vocab, requires_grad=True)
    teacher_logits = torch.randn(batch, seq_len, vocab)
    # All-ones mask: no position is excluded.
    full_mask = torch.ones(batch, seq_len)

    actual = taid_loss(student_logits, teacher_logits, full_mask, t=interpolation)
    expected = _upstream_compute_loss(
        student_logits, teacher_logits, full_mask, t=interpolation
    )

    assert expected is not None
    torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)