"""Upstream-parity test for TAID. This test compares our `taid_loss` against the reference implementation in ``SakanaAI/TAID/src/distil_losses/taid.py`` + ``fkl.py``. The upstream clone is expected at ``/tmp/taid-clone``; if absent, the test is skipped. To run:: git clone --depth 1 https://github.com/SakanaAI/TAID /tmp/taid-clone pytest composer_replication/distillation/tests/test_taid_parity.py Parity is asserted at atol/rtol = 1e-5 over a small batch on CPU. """ from __future__ import annotations import importlib.util import os import sys import pytest import torch from composer_replication.distillation import taid_loss UPSTREAM_PATH = os.environ.get("TAID_UPSTREAM_PATH", "/tmp/taid-clone") UPSTREAM_TAID_PY = os.path.join(UPSTREAM_PATH, "src", "distil_losses", "taid.py") UPSTREAM_FKL_PY = os.path.join(UPSTREAM_PATH, "src", "distil_losses", "fkl.py") def _load_upstream_forward_kl(): """Inline-load just the `forward_kl` function from the upstream clone. We avoid importing the full upstream module because it depends on `lightning` and a relative `.base` import. Instead we read the file and exec just the function body. """ if not (os.path.isfile(UPSTREAM_TAID_PY) and os.path.isfile(UPSTREAM_FKL_PY)): return None src = open(UPSTREAM_FKL_PY).read() # Strip the module-level `from .base import DistilLoss` so we can exec # standalone — only forward_kl is needed for parity. sandbox: dict = {} # Build a minimal namespace that mimics the upstream imports. exec( "from typing import Optional\n" "import torch\n" "from torch.nn import functional as F\n", sandbox, ) # Append the forward_kl function definition. fwd_kl_src = src.split("def forward_kl(", 1)[1] fwd_kl_src = "def forward_kl(" + fwd_kl_src.split("\nclass ", 1)[0] exec(fwd_kl_src, sandbox) return sandbox["forward_kl"] def _upstream_compute_loss(student_logits, teacher_logits, mask, t): """Replicate `TAID.compute_loss` from upstream taid.py:66-80 inline. Same arithmetic; we just don't instantiate the LightningModule bookkeeping around it. """ forward_kl = _load_upstream_forward_kl() if forward_kl is None: return None import torch.nn.functional as F p_t = (1 - t) * student_logits.detach() + t * teacher_logits p_t = F.softmax(p_t, dim=-1, dtype=torch.float32) distil_loss = forward_kl( logits=student_logits, teacher_logits=teacher_logits, mask=mask, teacher_probs=p_t, ) return distil_loss @pytest.mark.skipif( not os.path.isfile(UPSTREAM_TAID_PY), reason=( f"Upstream TAID clone not found at {UPSTREAM_PATH}. " f"Run: git clone --depth 1 https://github.com/SakanaAI/TAID {UPSTREAM_PATH}" ), ) @pytest.mark.parametrize("t", [0.0, 0.1, 0.4, 0.5, 0.9, 1.0]) def test_taid_parity_against_upstream(t): """Our taid_loss matches upstream TAID.compute_loss(...) within atol=1e-5. Tests across the full t-range, on a fixed-seed batch with random logits + a non-trivial mask. """ torch.manual_seed(0) B, T, V = 2, 4, 16 student = torch.randn(B, T, V, requires_grad=True) teacher = torch.randn(B, T, V) mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float32) ours = taid_loss(student, teacher, mask, t=t) theirs = _upstream_compute_loss(student, teacher, mask, t=t) assert theirs is not None, "upstream forward_kl could not be loaded" torch.testing.assert_close(ours, theirs, atol=1e-5, rtol=1e-5) @pytest.mark.skipif( not os.path.isfile(UPSTREAM_TAID_PY), reason=f"Upstream TAID clone not found at {UPSTREAM_PATH}.", ) def test_taid_parity_with_full_mask(): """Sanity: full-mask path also matches upstream.""" torch.manual_seed(1) B, T, V = 1, 3, 8 student = torch.randn(B, T, V, requires_grad=True) teacher = torch.randn(B, T, V) mask = torch.ones(B, T) ours = taid_loss(student, teacher, mask, t=0.4) theirs = _upstream_compute_loss(student, teacher, mask, t=0.4) assert theirs is not None torch.testing.assert_close(ours, theirs, atol=1e-5, rtol=1e-5)