"""Numerical parity test against the upstream OPSD reference. Loads `OPSDTrainer.generalized_jsd_loss` from a clone of siyan-zhao/OPSD at /tmp/opsd-clone (override with $OPSD_CLONE) and asserts our re-implementation in `composer_replication.opsd` matches it byte-for-byte across a grid of shapes and β values. Skips cleanly when the upstream clone is absent. Why this lives in `tests/` rather than docs: numerical parity is the contract for this lift. If a future refactor of `generalized_jsd_loss` silently shifts gradients again, this test fails immediately. """ from __future__ import annotations import importlib.util import os import sys from pathlib import Path import pytest import torch from composer_replication.opsd import generalized_jsd_loss # ---------------------------------------------------------------------- # Locate upstream OPSDTrainer.generalized_jsd_loss # ---------------------------------------------------------------------- _OPSD_CLONE = Path(os.environ.get("OPSD_CLONE", "/tmp/opsd-clone")) _OPSD_TRAINER_PATH = _OPSD_CLONE / "opsd_trainer.py" def _load_upstream(): """Import OPSDTrainer.generalized_jsd_loss from a local clone, isolated. The upstream `opsd_trainer.py` imports heavyweight TRL / transformers machinery at module scope, which we do not want to drag into the test process. We instead extract the static method by parsing the source text and exec-ing only that function body — it depends only on `torch` and `torch.nn.functional`, which are already importable. """ if not _OPSD_TRAINER_PATH.exists(): return None text = _OPSD_TRAINER_PATH.read_text() # Pull out the function block. It starts with `def generalized_jsd_loss(` # under `class OPSDTrainer` and ends at the next top-of-class `def `. start = text.find("def generalized_jsd_loss(") if start < 0: return None # Walk forward to the start of the next sibling method (4-space indent # `def ` or class-end) — they all start with exactly 4 spaces of indent. rest = text[start:] # Skip past the function header and find the next `\n def ` or # `\n @staticmethod` boundary. end_marker_offsets = [] for marker in ("\n @", "\n def ", "\nclass "): idx = rest.find(marker, len("def generalized_jsd_loss(")) if idx > 0: end_marker_offsets.append(idx) if not end_marker_offsets: return None fn_text = rest[: min(end_marker_offsets)] # Dedent (the source lines are 4-space indented as a class method). fn_text = "\n".join( line[4:] if line.startswith(" ") else line for line in fn_text.splitlines() ) # Exec into a fresh namespace with torch + F available. import torch.nn.functional as F # noqa: F401 (used by exec'd code) namespace: dict = {"torch": torch, "F": F} exec(compile(fn_text, str(_OPSD_TRAINER_PATH), "exec"), namespace) fn = namespace.get("generalized_jsd_loss") return fn _UPSTREAM_FN = _load_upstream() _SKIP_REASON = ( f"upstream OPSD clone not found at {_OPSD_TRAINER_PATH} " f"(set $OPSD_CLONE or `git clone --depth 1 https://github.com/siyan-zhao/OPSD {_OPSD_CLONE}`)" ) # ---------------------------------------------------------------------- # Parity grid # ---------------------------------------------------------------------- _SHAPES = [ (1, 4, 16), (2, 8, 32), (3, 5, 64), (1, 16, 8), (4, 3, 24), ] _BETAS = [0.0, 0.5, 1.0] @pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON) @pytest.mark.parametrize("shape", _SHAPES) @pytest.mark.parametrize("beta", _BETAS) def test_parity_unmasked(shape, beta): """Our `generalized_jsd_loss` must match upstream within 1e-5 atol.""" B, T, V = shape g = torch.Generator().manual_seed(13 + B * 31 + T * 17 + V) student = torch.randn(B, T, V, generator=g, dtype=torch.float64) teacher = torch.randn(B, T, V, generator=g, dtype=torch.float64) ours = generalized_jsd_loss(student, teacher, beta=beta) theirs = _UPSTREAM_FN(student, teacher, beta=beta) # type: ignore[misc] assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), ( f"mismatch at shape={shape} beta={beta}: ours={ours.item()} theirs={theirs.item()}" ) @pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON) @pytest.mark.parametrize("shape", _SHAPES) @pytest.mark.parametrize("beta", _BETAS) def test_parity_masked(shape, beta): """Same parity but with a labels mask that ignores ~half the tokens.""" B, T, V = shape g = torch.Generator().manual_seed(101 + B * 7 + T * 11 + V) student = torch.randn(B, T, V, generator=g, dtype=torch.float64) teacher = torch.randn(B, T, V, generator=g, dtype=torch.float64) # Random valid/ignored mask: -100 for ignored, anything else for valid. labels = torch.randint(0, 2, (B, T), generator=g) labels = torch.where(labels == 0, torch.full_like(labels, -100), labels) ours = generalized_jsd_loss(student, teacher, labels=labels, beta=beta) theirs = _UPSTREAM_FN(student, teacher, labels=labels, beta=beta) # type: ignore[misc] assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), ( f"mismatch at shape={shape} beta={beta}: ours={ours.item()} theirs={theirs.item()}" ) @pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON) def test_parity_temperature_and_topk(): """Spot-check the temperature + top_k branches against upstream.""" g = torch.Generator().manual_seed(42) student = torch.randn(2, 6, 32, generator=g, dtype=torch.float64) teacher = torch.randn(2, 6, 32, generator=g, dtype=torch.float64) for beta in (0.0, 0.3, 0.5, 0.7, 1.0): ours = generalized_jsd_loss(student, teacher, beta=beta, temperature=2.0, top_k=8) theirs = _UPSTREAM_FN( # type: ignore[misc] student, teacher, beta=beta, temperature=2.0, top_k=8 ) assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), ( f"temp+topk parity failed at beta={beta}: ours={ours.item()} theirs={theirs.item()}" )