Codeseys's picture
Wave 15: 4-angle multi-model self-critique caught 2 math BLOCKERs in primary loss kernels; fixed against upstream byte-for-byte + GSM8K example + ergonomics
e5add15
raw
history blame
6.1 kB
"""Numerical parity test against the upstream OPSD reference.
Loads `OPSDTrainer.generalized_jsd_loss` from a clone of siyan-zhao/OPSD at
/tmp/opsd-clone (override with $OPSD_CLONE) and asserts our re-implementation
in `composer_replication.opsd` matches it to within 1e-5 (atol and rtol) across a grid of
shapes and β values. Skips cleanly when the upstream clone is absent.
Why this lives in `tests/` rather than docs: numerical parity is the
contract for this lift. If a future refactor of `generalized_jsd_loss`
silently shifts gradients again, this test fails immediately.
"""
from __future__ import annotations
import importlib.util
import os
import sys
from pathlib import Path
import pytest
import torch
from composer_replication.opsd import generalized_jsd_loss
# ----------------------------------------------------------------------
# Where the upstream OPSDTrainer.generalized_jsd_loss source lives
# ----------------------------------------------------------------------
# Root of the local OPSD checkout; the $OPSD_CLONE environment variable
# overrides the /tmp default.
_OPSD_CLONE = Path(os.getenv("OPSD_CLONE", "/tmp/opsd-clone"))
# Trainer source file that `_load_upstream` reads the reference function from.
_OPSD_TRAINER_PATH = _OPSD_CLONE.joinpath("opsd_trainer.py")
def _load_upstream():
    """Extract upstream `generalized_jsd_loss` from the local OPSD clone.

    The upstream `opsd_trainer.py` imports heavyweight TRL / transformers
    machinery at module scope, which we do not want to drag into the test
    process. Instead of importing the module, the source text is read, the
    `def generalized_jsd_loss(` block is cut out, de-indented and exec'd in
    a namespace that provides only `torch` and `F`.

    The block runs from the `def` line up to the first following decorator
    or `def` at the same indentation, or the next top-level `class`. The
    indentation is measured from the `def` line itself, so the extraction
    does not depend on the upstream file using a particular indent width.

    Returns:
        The extracted callable, or ``None`` when the trainer file is
        missing, the function definition is not found, or nothing that
        marks the end of the function follows it.
    """
    import re

    import torch.nn.functional as F  # noqa: F401 (used by exec'd code)

    if not _OPSD_TRAINER_PATH.is_file():
        return None
    # Python source is UTF-8 by default; do not depend on the locale encoding.
    text = _OPSD_TRAINER_PATH.read_text(encoding="utf-8")

    # Anchor on a real definition line (only whitespace before `def`), so a
    # comment or string that merely mentions the function is not picked up.
    header = re.search(r"^([ \t]*)def generalized_jsd_loss\(", text, re.MULTILINE)
    if header is None:
        return None
    indent = header.group(1)

    # The function ends at the next sibling decorator / method (same indent
    # as this `def`) or at the next top-level class, whichever comes first.
    # A multi-line signature's closing `):` sits at the same indent but is
    # not one of these markers, so it does not end the block early.
    end_markers = (f"\n{indent}@", f"\n{indent}def ", "\nclass ")
    hits = (text.find(marker, header.end()) for marker in end_markers)
    boundaries = [pos for pos in hits if pos >= 0]
    if not boundaries:
        return None
    block_text = text[header.start(): min(boundaries)]

    # Strip exactly the method's own indentation. Lines that do not carry it
    # (blank lines, flush-left text inside a multi-line string) stay as-is.
    width = len(indent)
    fn_text = "\n".join(
        line[width:] if line.startswith(indent) else line
        for line in block_text.splitlines()
    )

    # NOTE(security): this executes code taken from the clone on disk. Only
    # point $OPSD_CLONE at a checkout you trust.
    namespace: dict = {"torch": torch, "F": F}
    exec(compile(fn_text, str(_OPSD_TRAINER_PATH), "exec"), namespace)
    return namespace.get("generalized_jsd_loss")
# Resolved once at import time; `None` means the parity tests are skipped.
_UPSTREAM_FN = _load_upstream()
# Report the skip cause that actually applies: a present-but-unparseable
# trainer file is a different problem from a missing clone, and saying
# "not found" in that case would hide upstream layout drift.
if _OPSD_TRAINER_PATH.exists():
    _SKIP_REASON = (
        f"could not extract `generalized_jsd_loss` from {_OPSD_TRAINER_PATH} "
        "(upstream layout may have changed; check `_load_upstream`)"
    )
else:
    _SKIP_REASON = (
        f"upstream OPSD clone not found at {_OPSD_TRAINER_PATH} "
        f"(set $OPSD_CLONE or `git clone --depth 1 https://github.com/siyan-zhao/OPSD {_OPSD_CLONE}`)"
    )
# ----------------------------------------------------------------------
# Parametrization grid for the parity tests
# ----------------------------------------------------------------------
# Each entry is unpacked by the tests as `B, T, V` and used as the shape of
# the random student / teacher logits tensors.
_SHAPES = [
    (1, 4, 16), (2, 8, 32), (3, 5, 64), (1, 16, 8), (4, 3, 24),
]
# β values passed as the `beta` keyword to both implementations.
_BETAS = [0.0, 0.5, 1.0]
@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
@pytest.mark.parametrize("shape", _SHAPES)
@pytest.mark.parametrize("beta", _BETAS)
def test_parity_unmasked(shape, beta):
    """Check parity with upstream on random logits when no labels are passed.

    The generator seed is derived from the shape, so every shape gets its
    own fixed pair of float64 student / teacher tensors. The two results
    must agree within 1e-5 (both atol and rtol).
    """
    batch, steps, vocab = shape
    rng = torch.Generator()
    rng.manual_seed(13 + batch * 31 + steps * 17 + vocab)
    # Draw order is part of the fixture: student first, then teacher.
    student, teacher = (
        torch.randn(batch, steps, vocab, generator=rng, dtype=torch.float64)
        for _ in range(2)
    )

    ours = generalized_jsd_loss(student, teacher, beta=beta)
    theirs = _UPSTREAM_FN(student, teacher, beta=beta)  # type: ignore[misc]

    within_tolerance = torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5)
    assert within_tolerance, (
        f"mismatch at shape={shape} beta={beta}: ours={ours.item()} theirs={theirs.item()}"
    )
@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
@pytest.mark.parametrize("shape", _SHAPES)
@pytest.mark.parametrize("beta", _BETAS)
def test_parity_masked(shape, beta):
    """Check parity with upstream when a `labels` tensor is supplied.

    Labels are drawn uniformly from {0, 1}; every 0 is rewritten to -100
    (the value treated as "ignored"), so roughly half the positions are
    masked. The two results must agree within 1e-5 (both atol and rtol).
    """
    batch, steps, vocab = shape
    rng = torch.Generator()
    rng.manual_seed(101 + batch * 7 + steps * 11 + vocab)
    # Draw order is part of the fixture: student, teacher, then the labels.
    student = torch.randn(batch, steps, vocab, generator=rng, dtype=torch.float64)
    teacher = torch.randn(batch, steps, vocab, generator=rng, dtype=torch.float64)
    coin = torch.randint(0, 2, (batch, steps), generator=rng)
    # -100 marks an ignored position; any other value counts as valid.
    labels = coin.masked_fill(coin == 0, -100)

    ours = generalized_jsd_loss(student, teacher, labels=labels, beta=beta)
    theirs = _UPSTREAM_FN(student, teacher, labels=labels, beta=beta)  # type: ignore[misc]

    within_tolerance = torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5)
    assert within_tolerance, (
        f"mismatch at shape={shape} beta={beta}: ours={ours.item()} theirs={theirs.item()}"
    )
@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
def test_parity_temperature_and_topk():
    """Compare against upstream with `temperature=2.0` and `top_k=8` set.

    One fixed pair of (2, 6, 32) float64 tensors is reused for five β
    values; each result must agree within 1e-5 (both atol and rtol).
    """
    rng = torch.Generator()
    rng.manual_seed(42)
    student = torch.randn(2, 6, 32, generator=rng, dtype=torch.float64)
    teacher = torch.randn(2, 6, 32, generator=rng, dtype=torch.float64)

    # Same non-default options go to both implementations.
    options = {"temperature": 2.0, "top_k": 8}
    for beta in (0.0, 0.3, 0.5, 0.7, 1.0):
        ours = generalized_jsd_loss(student, teacher, beta=beta, **options)
        theirs = _UPSTREAM_FN(student, teacher, beta=beta, **options)  # type: ignore[misc]
        within_tolerance = torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5)
        assert within_tolerance, (
            f"temp+topk parity failed at beta={beta}: ours={ours.item()} theirs={theirs.item()}"
        )