File size: 6,102 Bytes
e5add15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""Numerical parity test against the upstream OPSD reference.

Loads `OPSDTrainer.generalized_jsd_loss` from a clone of siyan-zhao/OPSD at
/tmp/opsd-clone (override with $OPSD_CLONE) and asserts our re-implementation
in `composer_replication.opsd` matches it byte-for-byte across a grid of
shapes and β values. Skips cleanly when the upstream clone is absent.

Why this lives in `tests/` rather than docs: numerical parity is the
contract for this lift. If a future refactor of `generalized_jsd_loss`
silently shifts gradients again, this test fails immediately.
"""

from __future__ import annotations

import importlib.util
import os
import sys
from pathlib import Path

import pytest
import torch

from composer_replication.opsd import generalized_jsd_loss

# ----------------------------------------------------------------------
# Locate upstream OPSDTrainer.generalized_jsd_loss
# ----------------------------------------------------------------------

_OPSD_CLONE = Path(os.environ.get("OPSD_CLONE", "/tmp/opsd-clone"))
_OPSD_TRAINER_PATH = _OPSD_CLONE / "opsd_trainer.py"


def _load_upstream():
    """Import OPSDTrainer.generalized_jsd_loss from a local clone, isolated.

    The upstream `opsd_trainer.py` imports heavyweight TRL / transformers
    machinery at module scope, which we do not want to drag into the test
    process. We instead extract the static method by parsing the source
    text and exec-ing only that function body — it depends only on
    `torch` and `torch.nn.functional`, which are already importable.
    """
    if not _OPSD_TRAINER_PATH.exists():
        return None

    text = _OPSD_TRAINER_PATH.read_text()
    # Pull out the function block. It starts with `def generalized_jsd_loss(`
    # under `class OPSDTrainer` and ends at the next top-of-class `def `.
    start = text.find("def generalized_jsd_loss(")
    if start < 0:
        return None
    # Walk forward to the start of the next sibling method (4-space indent
    # `def ` or class-end) — they all start with exactly 4 spaces of indent.
    rest = text[start:]
    # Skip past the function header and find the next `\n    def ` or
    # `\n    @staticmethod` boundary.
    end_marker_offsets = []
    for marker in ("\n    @", "\n    def ", "\nclass "):
        idx = rest.find(marker, len("def generalized_jsd_loss("))
        if idx > 0:
            end_marker_offsets.append(idx)
    if not end_marker_offsets:
        return None
    fn_text = rest[: min(end_marker_offsets)]

    # Dedent (the source lines are 4-space indented as a class method).
    fn_text = "\n".join(
        line[4:] if line.startswith("    ") else line for line in fn_text.splitlines()
    )

    # Exec into a fresh namespace with torch + F available.
    import torch.nn.functional as F  # noqa: F401  (used by exec'd code)

    namespace: dict = {"torch": torch, "F": F}
    exec(compile(fn_text, str(_OPSD_TRAINER_PATH), "exec"), namespace)
    fn = namespace.get("generalized_jsd_loss")
    return fn


_UPSTREAM_FN = _load_upstream()
_SKIP_REASON = (
    f"upstream OPSD clone not found at {_OPSD_TRAINER_PATH} "
    f"(set $OPSD_CLONE or `git clone --depth 1 https://github.com/siyan-zhao/OPSD {_OPSD_CLONE}`)"
)


# ----------------------------------------------------------------------
# Parity grid
# ----------------------------------------------------------------------

_SHAPES = [
    (1, 4, 16),
    (2, 8, 32),
    (3, 5, 64),
    (1, 16, 8),
    (4, 3, 24),
]
_BETAS = [0.0, 0.5, 1.0]


@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
@pytest.mark.parametrize("shape", _SHAPES)
@pytest.mark.parametrize("beta", _BETAS)
def test_parity_unmasked(shape, beta):
    """Our `generalized_jsd_loss` must match upstream within 1e-5 atol."""
    B, T, V = shape
    g = torch.Generator().manual_seed(13 + B * 31 + T * 17 + V)
    student = torch.randn(B, T, V, generator=g, dtype=torch.float64)
    teacher = torch.randn(B, T, V, generator=g, dtype=torch.float64)

    ours = generalized_jsd_loss(student, teacher, beta=beta)
    theirs = _UPSTREAM_FN(student, teacher, beta=beta)  # type: ignore[misc]

    assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), (
        f"mismatch at shape={shape} beta={beta}: ours={ours.item()} theirs={theirs.item()}"
    )


@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
@pytest.mark.parametrize("shape", _SHAPES)
@pytest.mark.parametrize("beta", _BETAS)
def test_parity_masked(shape, beta):
    """Same parity but with a labels mask that ignores ~half the tokens."""
    B, T, V = shape
    g = torch.Generator().manual_seed(101 + B * 7 + T * 11 + V)
    student = torch.randn(B, T, V, generator=g, dtype=torch.float64)
    teacher = torch.randn(B, T, V, generator=g, dtype=torch.float64)
    # Random valid/ignored mask: -100 for ignored, anything else for valid.
    labels = torch.randint(0, 2, (B, T), generator=g)
    labels = torch.where(labels == 0, torch.full_like(labels, -100), labels)

    ours = generalized_jsd_loss(student, teacher, labels=labels, beta=beta)
    theirs = _UPSTREAM_FN(student, teacher, labels=labels, beta=beta)  # type: ignore[misc]

    assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), (
        f"mismatch at shape={shape} beta={beta}: ours={ours.item()} theirs={theirs.item()}"
    )


@pytest.mark.skipif(_UPSTREAM_FN is None, reason=_SKIP_REASON)
def test_parity_temperature_and_topk():
    """Spot-check the temperature + top_k branches against upstream."""
    g = torch.Generator().manual_seed(42)
    student = torch.randn(2, 6, 32, generator=g, dtype=torch.float64)
    teacher = torch.randn(2, 6, 32, generator=g, dtype=torch.float64)

    for beta in (0.0, 0.3, 0.5, 0.7, 1.0):
        ours = generalized_jsd_loss(student, teacher, beta=beta, temperature=2.0, top_k=8)
        theirs = _UPSTREAM_FN(  # type: ignore[misc]
            student, teacher, beta=beta, temperature=2.0, top_k=8
        )
        assert torch.allclose(ours, theirs, atol=1e-5, rtol=1e-5), (
            f"temp+topk parity failed at beta={beta}: ours={ours.item()} theirs={theirs.item()}"
        )