Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto") - Notebooks
- Google Colab
- Kaggle
feat(wave-b): ADR-013 LMA integration + B4 end-to-end SDPO-fires proof + doc refresh
Browse filesADR-013: composer_replication/integrations/altered_minds/ (generic, framework-side)
- MMLUFormatReward: structured-answer reward, scores only the final letter (not
rationale style), unparseable/multiple-answer penalties, option-order randomization,
always-C exploit detectable via logged option distribution.
- dual_kl_logger: KL(policy||altered-init) AND KL(policy||unaltered-base) as the
washout-vs-amplification instrument (neither optimized).
- channel_ladder_configs: A0-A4 isolated-channel ladder (alpha=0.02/beta=0.05),
replacing the uninterpretable combined alpha=0.2/beta=0.4 recipe per the
cross-family research critique (SDPO can AMPLIFY an altered model).
B4: examples/altered_minds_channel_ladder/run.py PROVES the SDPO channel FIRES
nonzero (JSD=0.0565, grad flows) through the REAL shipped collator alignment
indices — the thing the old smoke could not show (it only proved init). Honest
stub-with-differing-tokens proof; real-model path gated behind ALTERED_MINDS_REAL_MODEL=1.
ALTERED_MINDS_TIE_IN.md Phase-3 hyperparams superseded -> ADR-013 ladder.
BACKLOG.md + VISION_VALIDATION.md refreshed to actual shipped state.
227 passed / 16 skipped (was 210). NO real LMA checkpoint / Modal / budget spend
(user-gated). Workers: 2x Opus-4.8 + Gemini-3.1-pro (docs).
- BACKLOG.md +26 -19
- composer_replication/integrations/__init__.py +0 -0
- composer_replication/integrations/altered_minds/__init__.py +48 -0
- composer_replication/integrations/altered_minds/kl_logging.py +110 -0
- composer_replication/integrations/altered_minds/ladder.py +97 -0
- composer_replication/integrations/altered_minds/reward.py +231 -0
- composer_replication/integrations/altered_minds/tests/__init__.py +0 -0
- composer_replication/integrations/altered_minds/tests/test_channel_ladder.py +362 -0
- docs/ALTERED_MINDS_TIE_IN.md +37 -10
- docs/VISION_VALIDATION.md +7 -0
- examples/altered_minds_channel_ladder/README.md +52 -0
- examples/altered_minds_channel_ladder/run.py +229 -0
|
@@ -1,8 +1,23 @@
|
|
| 1 |
# Backlog — Composer 2.5 Replication Framework
|
| 2 |
|
| 3 |
-
|
| 4 |
|
| 5 |
-
## Active items
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
### Spike 006 — Real HF model smoke (Wave 7)
|
| 8 |
|
|
@@ -60,23 +75,15 @@ Imported from `docs/VISION_VALIDATION.md` § 6 (gaps) + § 9 (gap-closers) at 20
|
|
| 60 |
4. README quickstart updated to `pip install -e .` + `python examples/qwen3_05b_quickstart/run.py`.
|
| 61 |
5. `pip install -e .` succeeds and quickstart runs end-to-end on CPU.
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
**
|
| 70 |
-
|
| 71 |
-
**
|
| 72 |
-
|
| 73 |
-
**Acceptance**:
|
| 74 |
-
1. ADR-001 says Modal is the right choice for this workload + estimate is < $5.
|
| 75 |
-
2. Modal app builds, runs `composer_total_loss` for 50 steps on Qwen2.5-0.5B-Instruct.
|
| 76 |
-
3. Loss curve + memory profile saved to `spikes/002a-mini/` and pulled to local.
|
| 77 |
-
4. No new shape / dtype bug surfaced vs CPU run.
|
| 78 |
-
|
| 79 |
-
**Estimate**: $1–3, 30 min wall-clock.
|
| 80 |
|
| 81 |
## Deferred (post-loop, GPU-gated)
|
| 82 |
|
|
|
|
| 1 |
# Backlog — Composer 2.5 Replication Framework
|
| 2 |
|
| 3 |
+
Updated 2026-05-29 to reflect shipped waves (ingestion, diloco, packaging, datagen+RL, ADR-011/012/013, cross-family review).
|
| 4 |
|
| 5 |
+
## Active items / Honest Gaps
|
| 6 |
+
|
| 7 |
+
### Framework/Docker substrate E2E (Hardware-blocked)
|
| 8 |
+
- We lack the local multi-node GPU environment to run the true 8-node DiLoCo + Docker/TorchForge orchestrator E2E tests. Currently isolated to unit-level and single-node pseudo-gradient checks.
|
| 9 |
+
|
| 10 |
+
### Real 8B LMA run (User-budget-gated)
|
| 11 |
+
- The framework is proven on Qwen-0.5B and 1.5B (GSM8K/SDPO math traces).
|
| 12 |
+
- The ultimate goal (Llama-3-8B full LMA run with α/β ablation over 10k SWE-bench traces) requires a multi-GPU Modal drop + significant compute budget.
|
| 13 |
+
|
| 14 |
+
## Modal-gated (if budget allows after gap-closers)
|
| 15 |
+
|
| 16 |
+
### Spike 002a-mini — Real GPU smoke (Phase 10)
|
| 17 |
+
**Closes**: the "did we ever run gradients on GPU" ambiguity — currently everything is CPU-only.
|
| 18 |
+
- Goal: dispatch a 30-min A10G smoke on Modal that runs Qwen2.5-0.5B-Instruct natively on GPU.
|
| 19 |
+
|
| 20 |
+
## Shipped (Past-Skeleton)
|
| 21 |
|
| 22 |
### Spike 006 — Real HF model smoke (Wave 7)
|
| 23 |
|
|
|
|
| 75 |
4. README quickstart updated to `pip install -e .` + `python examples/qwen3_05b_quickstart/run.py`.
|
| 76 |
5. `pip install -e .` succeeds and quickstart runs end-to-end on CPU.
|
| 77 |
|
| 78 |
+
### Post-Skeleton Waves (Datagen, Alignment, Quality)
|
| 79 |
+
- **Trace Ingestion**: Shipped (`composer_replication/ingestion/`).
|
| 80 |
+
- **DiLoCo**: Shipped (`composer_replication/diloco/` outer-loop pseudo optimizer).
|
| 81 |
+
- **Packaging**: Shipped (`pip install -e .` works perfectly).
|
| 82 |
+
- **ADR-008/009/010 (Datagen, Layered Hints, Dr.GRPO+SDPO)**: Shipped, examples documented.
|
| 83 |
+
- **Cross-Family Architectural Review**: Shipped (`docs/reviews/cross-family-adr-008-009-010-2026-05-29/`).
|
| 84 |
+
- **Alignment / V&V Closure**: ADR-011 (SDPO alignment indices), ADR-012 (close review findings), ADR-013 (LMA integration channel-ladder) shipped.
|
| 85 |
+
- **Test Suites**: 210 passed / 16 skipped.
|
| 86 |
+
- **Real Examples**: `examples/gsm8k_grpo/`, `examples/sdpo_with_real_traces_production/`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
## Deferred (post-loop, GPU-gated)
|
| 89 |
|
|
File without changes
|
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""altered_minds — framework-side, generic LMA integration glue (ADR-013).
|
| 2 |
+
|
| 3 |
+
This package is the *model-agnostic* scaffold that lets the Composer Replication
|
| 4 |
+
Framework drive the sister project llm-mental-alterations (LMA): take a
|
| 5 |
+
personality-altered SFT checkpoint and apply the framework's 3-channel RL to ask
|
| 6 |
+
whether task-driven RL washes out, preserves, or AMPLIFIES the alteration's
|
| 7 |
+
cognitive-distortion signature.
|
| 8 |
+
|
| 9 |
+
Nothing here loads an LMA checkpoint, calls Modal, or spends budget — that is
|
| 10 |
+
explicitly user-gated (ADR-013 "out of scope"). This package provides:
|
| 11 |
+
|
| 12 |
+
- ``MMLUFormatReward`` : structured-answer reward (final letter + format
|
| 13 |
+
only; never rationale style). Plus
|
| 14 |
+
``randomize_options`` and a logged option
|
| 15 |
+
distribution so an "always C" exploit is
|
| 16 |
+
detectable.
|
| 17 |
+
- ``dual_kl_logger`` : logs KL(policy||altered_init) AND KL(policy||base)
|
| 18 |
+
each step — the washout/amplification instrument.
|
| 19 |
+
- ``channel_ladder_configs``: the A0-A4 isolated-channel ladder that REPLACES
|
| 20 |
+
the old combined alpha=0.2/beta=0.4 recipe.
|
| 21 |
+
|
| 22 |
+
See docs/adrs/ADR-013-lma-integration-channel-ladder.md.
|
| 23 |
+
"""
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
from composer_replication.integrations.altered_minds.kl_logging import (
|
| 27 |
+
dual_kl_logger,
|
| 28 |
+
token_mean_kl,
|
| 29 |
+
)
|
| 30 |
+
from composer_replication.integrations.altered_minds.ladder import (
|
| 31 |
+
LADDER_KL_BETA,
|
| 32 |
+
channel_ladder_configs,
|
| 33 |
+
)
|
| 34 |
+
from composer_replication.integrations.altered_minds.reward import (
|
| 35 |
+
MMLUFormatReward,
|
| 36 |
+
parse_final_answer,
|
| 37 |
+
randomize_options,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
__all__ = [
|
| 41 |
+
"MMLUFormatReward",
|
| 42 |
+
"parse_final_answer",
|
| 43 |
+
"randomize_options",
|
| 44 |
+
"dual_kl_logger",
|
| 45 |
+
"token_mean_kl",
|
| 46 |
+
"channel_ladder_configs",
|
| 47 |
+
"LADDER_KL_BETA",
|
| 48 |
+
]
|
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""kl_logging.py — dual_kl_logger (ADR-013, framework-side, generic).
|
| 2 |
+
|
| 3 |
+
The washout/amplification instrument. Given per-token logprobs from three
|
| 4 |
+
forward passes on the SAME answer+reasoning tokens:
|
| 5 |
+
|
| 6 |
+
- policy: the model currently being RL-trained
|
| 7 |
+
- altered_init: the altered SFT checkpoint the run STARTED from (the locus
|
| 8 |
+
of the cognitive-distortion signature)
|
| 9 |
+
- unaltered_base: the original base model BEFORE personality SFT
|
| 10 |
+
|
| 11 |
+
returns ``{'kl_to_altered_init': float, 'kl_to_base': float}``.
|
| 12 |
+
|
| 13 |
+
NEITHER KL is optimized by default — both are diagnostics:
|
| 14 |
+
- ``kl_to_altered_init`` rising means the policy is moving AWAY from the
|
| 15 |
+
altered checkpoint (task-RL is *changing* the alteration).
|
| 16 |
+
- ``kl_to_base`` measures distance to the unaltered base. If
|
| 17 |
+
``kl_to_base`` SHRINKS while ``kl_to_altered_init`` grows, the alteration
|
| 18 |
+
is WASHING OUT (the policy drifts back toward base). If ``kl_to_base``
|
| 19 |
+
GROWS faster than ``kl_to_altered_init``, the alteration is being AMPLIFIED
|
| 20 |
+
(the policy moves further from base than the altered init already was) —
|
| 21 |
+
the ADR-013 amplification hypothesis, most likely on the SDPO channel.
|
| 22 |
+
|
| 23 |
+
Token-mean KL is used (mean over the masked answer+reasoning tokens), the
|
| 24 |
+
standard diagnostic convention. The math is the discrete KL between the two
|
| 25 |
+
softmax distributions implied by the logprob tensors:
|
| 26 |
+
|
| 27 |
+
KL(p || q) = sum_v p_v (log p_v - log q_v)
|
| 28 |
+
|
| 29 |
+
where ``p`` is the policy's per-token distribution. This is unit-testable on
|
| 30 |
+
toy tensors: KL(p || p) == 0, and KL grows monotonically as the policy moves.
|
| 31 |
+
"""
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
from typing import Any
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
|
| 38 |
+
__all__ = ["dual_kl_logger", "token_mean_kl"]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _as_log_probs(logprobs: torch.Tensor) -> torch.Tensor:
|
| 42 |
+
"""Normalize an input that may be raw logits OR already-log-probs to valid
|
| 43 |
+
log-probabilities along the last (vocab) dim.
|
| 44 |
+
|
| 45 |
+
We re-apply ``log_softmax`` defensively: it is idempotent on a genuine
|
| 46 |
+
log-prob tensor up to floating point (log_softmax of log-probs == log-probs
|
| 47 |
+
since they already sum-exp to 1), and converts raw logits correctly. This
|
| 48 |
+
makes the logger robust to either calling convention.
|
| 49 |
+
"""
|
| 50 |
+
return torch.log_softmax(logprobs.to(torch.float64), dim=-1)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def token_mean_kl(
|
| 54 |
+
policy_logprobs: torch.Tensor,
|
| 55 |
+
ref_logprobs: torch.Tensor,
|
| 56 |
+
mask: torch.Tensor | None = None,
|
| 57 |
+
) -> float:
|
| 58 |
+
"""Token-mean KL(policy || ref) over distributions on the last dim.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
policy_logprobs: (..., V) logits or log-probs for the policy.
|
| 62 |
+
ref_logprobs: (..., V) logits or log-probs for the reference.
|
| 63 |
+
mask: optional (...,) mask of tokens to include (1/True = include). If
|
| 64 |
+
None, all tokens count.
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
scalar token-mean KL as a python float (>= 0 up to float error).
|
| 68 |
+
"""
|
| 69 |
+
log_p = _as_log_probs(policy_logprobs)
|
| 70 |
+
log_q = _as_log_probs(ref_logprobs)
|
| 71 |
+
p = log_p.exp()
|
| 72 |
+
# per-token KL: sum over vocab of p * (log p - log q)
|
| 73 |
+
per_token = (p * (log_p - log_q)).sum(dim=-1) # (...,)
|
| 74 |
+
|
| 75 |
+
if mask is not None:
|
| 76 |
+
m = mask.to(per_token.dtype)
|
| 77 |
+
denom = m.sum()
|
| 78 |
+
if float(denom) == 0.0:
|
| 79 |
+
return 0.0
|
| 80 |
+
return float((per_token * m).sum() / denom)
|
| 81 |
+
return float(per_token.mean())
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def dual_kl_logger(
|
| 85 |
+
policy_logprobs: torch.Tensor,
|
| 86 |
+
altered_init_logprobs: torch.Tensor,
|
| 87 |
+
unaltered_base_logprobs: torch.Tensor,
|
| 88 |
+
mask: torch.Tensor | None = None,
|
| 89 |
+
**_: Any,
|
| 90 |
+
) -> dict[str, float]:
|
| 91 |
+
"""Compute the two diagnostic KLs for a step.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
policy_logprobs: (..., V) policy logits/log-probs on the
|
| 95 |
+
answer+reasoning tokens.
|
| 96 |
+
altered_init_logprobs: (..., V) for the altered SFT init.
|
| 97 |
+
unaltered_base_logprobs:(..., V) for the unaltered base.
|
| 98 |
+
mask: optional (...,) token mask (answer+reasoning tokens to score).
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
``{'kl_to_altered_init': float, 'kl_to_base': float}``.
|
| 102 |
+
"""
|
| 103 |
+
return {
|
| 104 |
+
"kl_to_altered_init": token_mean_kl(
|
| 105 |
+
policy_logprobs, altered_init_logprobs, mask
|
| 106 |
+
),
|
| 107 |
+
"kl_to_base": token_mean_kl(
|
| 108 |
+
policy_logprobs, unaltered_base_logprobs, mask
|
| 109 |
+
),
|
| 110 |
+
}
|
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ladder.py — channel_ladder_configs (ADR-013, the experiment design).
|
| 2 |
+
|
| 3 |
+
The isolated-channel ladder REPLACES the old combined alpha=0.2/beta=0.4 recipe
|
| 4 |
+
(superseded; see docs/ALTERED_MINDS_TIE_IN.md Phase 3). Per ADR-013, a combined
|
| 5 |
+
run confounds four effects (task RL, self-distillation of altered reasoning,
|
| 6 |
+
frontier-teacher imitation, KL anchoring) and is scientifically uninterpretable.
|
| 7 |
+
Worse, SDPO against the altered model's OWN hint-conditioned forward pass can
|
| 8 |
+
AMPLIFY the distortion, so it is an experimental intervention, not a stabilizer.
|
| 9 |
+
|
| 10 |
+
The ladder isolates channels so each effect is attributable:
|
| 11 |
+
|
| 12 |
+
| Arm | alpha_sdpo | beta_replay | Purpose |
|
| 13 |
+
|-----|------------|-------------|----------------------------------|
|
| 14 |
+
| A0 | — | — | altered SFT, no RL (control) |
|
| 15 |
+
| A1 | 0.0 | 0.0 | GRPO-only baseline |
|
| 16 |
+
| A2 | 0.02 | 0.0 | +SDPO small (amplification probe)|
|
| 17 |
+
| A3 | 0.0 | 0.05 | +replay-DPO small (washout probe)|
|
| 18 |
+
| A4 | 0.02 | 0.05 | combined — ONLY after A1-A3 |
|
| 19 |
+
|
| 20 |
+
``kl_beta`` (KL-to-altered-init coef) = 0.02 for all RL arms. A0 is a sentinel
|
| 21 |
+
(no RL) so its alpha/beta/kl_beta are None.
|
| 22 |
+
"""
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
from typing import Any
|
| 26 |
+
|
| 27 |
+
__all__ = ["channel_ladder_configs", "LADDER_KL_BETA"]
|
| 28 |
+
|
| 29 |
+
#: KL-to-altered-init coefficient applied to every RL arm (A1-A4).
|
| 30 |
+
LADDER_KL_BETA = 0.02
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def channel_ladder_configs() -> list[dict[str, Any]]:
|
| 34 |
+
"""Return the ordered A0-A4 arm configs.
|
| 35 |
+
|
| 36 |
+
Each arm is a dict with keys: ``arm``, ``alpha_sdpo``, ``beta_replay``,
|
| 37 |
+
``kl_beta``, ``note``. A0 is the no-RL sentinel (alpha/beta/kl_beta = None).
|
| 38 |
+
|
| 39 |
+
A runner sweeps these with IDENTICAL seeds/prompts so any observed change in
|
| 40 |
+
the alteration signature is attributable to the single channel that arm
|
| 41 |
+
turns on relative to A1.
|
| 42 |
+
"""
|
| 43 |
+
return [
|
| 44 |
+
{
|
| 45 |
+
"arm": "A0",
|
| 46 |
+
"alpha_sdpo": None,
|
| 47 |
+
"beta_replay": None,
|
| 48 |
+
"kl_beta": None,
|
| 49 |
+
"note": (
|
| 50 |
+
"Control: altered SFT checkpoint, NO RL. Sentinel arm used to "
|
| 51 |
+
"anchor the pre-RL alteration signature."
|
| 52 |
+
),
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"arm": "A1",
|
| 56 |
+
"alpha_sdpo": 0.0,
|
| 57 |
+
"beta_replay": 0.0,
|
| 58 |
+
"kl_beta": LADDER_KL_BETA,
|
| 59 |
+
"note": (
|
| 60 |
+
"GRPO-only baseline (both extra channels OFF). Isolates the "
|
| 61 |
+
"effect of task-driven RL alone on the alteration."
|
| 62 |
+
),
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"arm": "A2",
|
| 66 |
+
"alpha_sdpo": 0.02,
|
| 67 |
+
"beta_replay": 0.0,
|
| 68 |
+
"kl_beta": LADDER_KL_BETA,
|
| 69 |
+
"note": (
|
| 70 |
+
"+SDPO small (amplification probe). SDPO ONLY vs A1: tests "
|
| 71 |
+
"whether self-distillation against the altered model's own "
|
| 72 |
+
"hint-conditioned forward pass AMPLIFIES the distortion."
|
| 73 |
+
),
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"arm": "A3",
|
| 77 |
+
"alpha_sdpo": 0.0,
|
| 78 |
+
"beta_replay": 0.05,
|
| 79 |
+
"kl_beta": LADDER_KL_BETA,
|
| 80 |
+
"note": (
|
| 81 |
+
"+replay-DPO small (washout probe). Trace-replay-DPO ONLY vs "
|
| 82 |
+
"A1: tests whether frontier-teacher disagreement WASHES OUT the "
|
| 83 |
+
"alteration toward base."
|
| 84 |
+
),
|
| 85 |
+
},
|
| 86 |
+
{
|
| 87 |
+
"arm": "A4",
|
| 88 |
+
"alpha_sdpo": 0.02,
|
| 89 |
+
"beta_replay": 0.05,
|
| 90 |
+
"kl_beta": LADDER_KL_BETA,
|
| 91 |
+
"note": (
|
| 92 |
+
"Combined — run ONLY after A1-A3 are interpretable. Confounds "
|
| 93 |
+
"channels by design; meaningful only as a capstone once the "
|
| 94 |
+
"isolated arms are understood."
|
| 95 |
+
),
|
| 96 |
+
},
|
| 97 |
+
]
|
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""reward.py — MMLUFormatReward (ADR-013, framework-side, generic).
|
| 2 |
+
|
| 3 |
+
A structured-answer reward for RL on MMLU-style multiple-choice tasks. It scores
|
| 4 |
+
ONLY the final answer letter + format validity — never the rationale's style or
|
| 5 |
+
content. This is deliberate: the north-star use case (ADR-013) drives RL on a
|
| 6 |
+
*personality-altered* model, and rewarding "persuasive" rationale would reward
|
| 7 |
+
the very cognitive-distortion signature we are trying to measure rather than
|
| 8 |
+
distort the reward toward it.
|
| 9 |
+
|
| 10 |
+
Scoring per completion:
|
| 11 |
+
+1.0 final answer parses and equals the gold letter
|
| 12 |
+
0.0 final answer parses but is wrong
|
| 13 |
+
-0.2 no parseable final-answer marker (unparseable)
|
| 14 |
+
-0.1 multiple DISTINCT final-answer markers present (format hacking)
|
| 15 |
+
-len_penalty small penalty past a rationale character cap
|
| 16 |
+
|
| 17 |
+
Parsing accepts (case-insensitive, last match wins for the canonical letter):
|
| 18 |
+
- ``Answer: X`` (X in A-D)
|
| 19 |
+
- JSON ``{"answer": "X"}``
|
| 20 |
+
|
| 21 |
+
Exploit detection: ``MMLUFormatReward`` keeps a running count of chosen letters
|
| 22 |
+
(``option_distribution``) so an "always C" / option-prior exploit is detectable
|
| 23 |
+
by inspecting that distribution after a run. A companion ``randomize_options``
|
| 24 |
+
helper shuffles option order with an original->shuffled label remap so the
|
| 25 |
+
training data itself can be de-biased.
|
| 26 |
+
"""
|
| 27 |
+
from __future__ import annotations
|
| 28 |
+
|
| 29 |
+
import json
|
| 30 |
+
import re
|
| 31 |
+
from collections import Counter
|
| 32 |
+
from dataclasses import dataclass, field
|
| 33 |
+
from typing import Any
|
| 34 |
+
|
| 35 |
+
__all__ = ["MMLUFormatReward", "randomize_options", "parse_final_answer"]
|
| 36 |
+
|
| 37 |
+
_VALID_LETTERS = ("A", "B", "C", "D")
|
| 38 |
+
|
| 39 |
+
# ``Answer: X`` — tolerant of whitespace, optional markdown bold/asterisks.
|
| 40 |
+
_ANSWER_RE = re.compile(r"answer\s*[:\-]\s*\*{0,2}([A-D])\b", re.IGNORECASE)
|
| 41 |
+
# JSON ``{"answer": "X"}`` — extract the value of an "answer" key.
|
| 42 |
+
_JSON_ANSWER_RE = re.compile(
|
| 43 |
+
r'["\']answer["\']\s*:\s*["\']([A-D])["\']', re.IGNORECASE
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _find_markers(text: str) -> list[str]:
|
| 48 |
+
"""Return ALL final-answer letters found (uppercased), in order of appearance.
|
| 49 |
+
|
| 50 |
+
Used both to pick the canonical answer (last match wins) and to detect the
|
| 51 |
+
multiple-distinct-markers format-hacking case.
|
| 52 |
+
"""
|
| 53 |
+
markers: list[tuple[int, str]] = []
|
| 54 |
+
for m in _ANSWER_RE.finditer(text or ""):
|
| 55 |
+
markers.append((m.start(), m.group(1).upper()))
|
| 56 |
+
for m in _JSON_ANSWER_RE.finditer(text or ""):
|
| 57 |
+
markers.append((m.start(), m.group(1).upper()))
|
| 58 |
+
markers.sort(key=lambda p: p[0])
|
| 59 |
+
return [letter for _, letter in markers]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def parse_final_answer(completion: str) -> tuple[str | None, int]:
|
| 63 |
+
"""Parse the final answer letter from a completion.
|
| 64 |
+
|
| 65 |
+
Returns ``(letter_or_None, n_distinct_markers)``. ``letter`` is the LAST
|
| 66 |
+
marker found (last match wins). ``n_distinct_markers`` counts DISTINCT
|
| 67 |
+
letters across all markers (so two ``Answer: C`` are not penalized, but
|
| 68 |
+
``Answer: A ... Answer: B`` is).
|
| 69 |
+
"""
|
| 70 |
+
markers = _find_markers(completion)
|
| 71 |
+
if not markers:
|
| 72 |
+
return None, 0
|
| 73 |
+
distinct = len(set(markers))
|
| 74 |
+
return markers[-1], distinct
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass
|
| 78 |
+
class MMLUFormatReward:
|
| 79 |
+
"""Callable reward_fn(prompts, completions, *, answers, **kwargs) -> list[float].
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
rationale_char_cap: completions longer than this incur a small length
|
| 83 |
+
penalty (``length_penalty_per_char`` per char past the cap). Caps
|
| 84 |
+
verbosity without scoring rationale content.
|
| 85 |
+
length_penalty_per_char: per-character penalty past the cap.
|
| 86 |
+
correct_reward / wrong_reward / unparseable_reward /
|
| 87 |
+
multiple_answers_reward: the scalar rewards for each outcome.
|
| 88 |
+
|
| 89 |
+
Side effect: ``option_distribution`` (a Counter over chosen letters) and
|
| 90 |
+
``n_scored`` accumulate across calls so an "always C" exploit is detectable
|
| 91 |
+
via ``exploit_report()``.
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
rationale_char_cap: int = 512
|
| 95 |
+
length_penalty_per_char: float = 0.001
|
| 96 |
+
correct_reward: float = 1.0
|
| 97 |
+
wrong_reward: float = 0.0
|
| 98 |
+
unparseable_reward: float = -0.2
|
| 99 |
+
multiple_answers_reward: float = -0.1
|
| 100 |
+
option_distribution: Counter = field(default_factory=Counter)
|
| 101 |
+
n_scored: int = 0
|
| 102 |
+
|
| 103 |
+
def __call__(
|
| 104 |
+
self,
|
| 105 |
+
prompts: Any = None,
|
| 106 |
+
completions: list[str] | None = None,
|
| 107 |
+
*,
|
| 108 |
+
answers: list[str] | None = None,
|
| 109 |
+
**kwargs: Any,
|
| 110 |
+
) -> list[float]:
|
| 111 |
+
"""Score a batch of completions against gold ``answers`` (letters A-D).
|
| 112 |
+
|
| 113 |
+
``prompts`` is accepted for the TRL reward-fn signature but unused
|
| 114 |
+
(we score the completion text only). ``answers`` is required.
|
| 115 |
+
"""
|
| 116 |
+
if completions is None:
|
| 117 |
+
completions = []
|
| 118 |
+
if answers is None:
|
| 119 |
+
raise ValueError(
|
| 120 |
+
"MMLUFormatReward requires `answers` (the gold letters, one per "
|
| 121 |
+
"completion). Pass via reward_fn(..., answers=[...])."
|
| 122 |
+
)
|
| 123 |
+
if len(answers) != len(completions):
|
| 124 |
+
raise ValueError(
|
| 125 |
+
f"answers/completions length mismatch: {len(answers)} vs "
|
| 126 |
+
f"{len(completions)}."
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
rewards: list[float] = []
|
| 130 |
+
for completion, gold in zip(completions, answers):
|
| 131 |
+
rewards.append(self._score_one(completion, gold))
|
| 132 |
+
return rewards
|
| 133 |
+
|
| 134 |
+
def _score_one(self, completion: str, gold: str) -> float:
|
| 135 |
+
letter, n_distinct = parse_final_answer(completion)
|
| 136 |
+
self.n_scored += 1
|
| 137 |
+
|
| 138 |
+
if letter is None:
|
| 139 |
+
# Unparseable: no usable final-answer marker. Length penalty does
|
| 140 |
+
# not apply (we never even parsed a letter to reward/penalize).
|
| 141 |
+
return self.unparseable_reward
|
| 142 |
+
|
| 143 |
+
# Log the chosen letter for exploit detection (always-C etc.).
|
| 144 |
+
self.option_distribution[letter] += 1
|
| 145 |
+
|
| 146 |
+
if n_distinct > 1:
|
| 147 |
+
# Multiple DISTINCT markers — format hacking. Penalize regardless
|
| 148 |
+
# of correctness (the model is hedging / gaming the parser).
|
| 149 |
+
base = self.multiple_answers_reward
|
| 150 |
+
elif gold is not None and letter == str(gold).strip().upper():
|
| 151 |
+
base = self.correct_reward
|
| 152 |
+
else:
|
| 153 |
+
base = self.wrong_reward
|
| 154 |
+
|
| 155 |
+
return base - self._length_penalty(completion)
|
| 156 |
+
|
| 157 |
+
def _length_penalty(self, completion: str) -> float:
|
| 158 |
+
over = max(0, len(completion or "") - self.rationale_char_cap)
|
| 159 |
+
return self.length_penalty_per_char * over
|
| 160 |
+
|
| 161 |
+
# ------------------------------------------------------------------
|
| 162 |
+
# Exploit detection
|
| 163 |
+
# ------------------------------------------------------------------
|
| 164 |
+
def exploit_report(self) -> dict[str, Any]:
|
| 165 |
+
"""Summarize the chosen-letter distribution so an option-prior exploit
|
| 166 |
+
(e.g. "always C") is detectable.
|
| 167 |
+
|
| 168 |
+
Returns a dict with the raw counts, the most common letter, and its
|
| 169 |
+
fraction of all parsed answers. A healthy run is ~uniform over A-D; a
|
| 170 |
+
fraction near 1.0 for a single letter is the exploit signature.
|
| 171 |
+
"""
|
| 172 |
+
total = sum(self.option_distribution.values())
|
| 173 |
+
if total == 0:
|
| 174 |
+
return {
|
| 175 |
+
"counts": {},
|
| 176 |
+
"total_parsed": 0,
|
| 177 |
+
"most_common": None,
|
| 178 |
+
"max_fraction": 0.0,
|
| 179 |
+
}
|
| 180 |
+
letter, count = self.option_distribution.most_common(1)[0]
|
| 181 |
+
return {
|
| 182 |
+
"counts": dict(self.option_distribution),
|
| 183 |
+
"total_parsed": total,
|
| 184 |
+
"most_common": letter,
|
| 185 |
+
"max_fraction": count / total,
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def randomize_options(
|
| 190 |
+
item: dict[str, Any], seed: int
|
| 191 |
+
) -> tuple[dict[str, Any], dict[str, str]]:
|
| 192 |
+
"""Shuffle the multiple-choice option order, tracking original->shuffled letters.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
item: a dict with ``options`` (list[str], A-first ordering) and
|
| 196 |
+
``answer`` (the gold letter, A-D). Other keys are passed through.
|
| 197 |
+
seed: deterministic RNG seed for the shuffle.
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
``(shuffled_item, label_remap)`` where ``shuffled_item`` has the options
|
| 201 |
+
reordered and its ``answer`` updated to the gold option's NEW letter, and
|
| 202 |
+
``label_remap`` maps each ORIGINAL letter -> its NEW (shuffled) letter.
|
| 203 |
+
|
| 204 |
+
This de-biases an option-prior exploit at the data level: if the gold answer
|
| 205 |
+
is no longer correlated with a fixed position, "always C" stops working.
|
| 206 |
+
"""
|
| 207 |
+
import random
|
| 208 |
+
|
| 209 |
+
options = list(item.get("options", []))
|
| 210 |
+
n = len(options)
|
| 211 |
+
if n == 0:
|
| 212 |
+
return dict(item), {}
|
| 213 |
+
orig_letters = [chr(ord("A") + i) for i in range(n)]
|
| 214 |
+
|
| 215 |
+
rng = random.Random(seed)
|
| 216 |
+
perm = list(range(n))
|
| 217 |
+
rng.shuffle(perm)
|
| 218 |
+
# perm[new_pos] = old_pos => option at new_pos is the old option perm[new_pos]
|
| 219 |
+
shuffled_options = [options[perm[new]] for new in range(n)]
|
| 220 |
+
|
| 221 |
+
# original letter -> new letter: old index `perm[new]` moved to position `new`.
|
| 222 |
+
label_remap: dict[str, str] = {}
|
| 223 |
+
for new_pos, old_pos in enumerate(perm):
|
| 224 |
+
label_remap[orig_letters[old_pos]] = orig_letters[new_pos]
|
| 225 |
+
|
| 226 |
+
shuffled_item = dict(item)
|
| 227 |
+
shuffled_item["options"] = shuffled_options
|
| 228 |
+
gold = str(item.get("answer", "")).strip().upper()
|
| 229 |
+
if gold in label_remap:
|
| 230 |
+
shuffled_item["answer"] = label_remap[gold]
|
| 231 |
+
return shuffled_item, label_remap
|
|
File without changes
|
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the ADR-013 altered_minds integration glue + the B4 SDPO-fires proof.
|
| 2 |
+
|
| 3 |
+
Covers the ADR-013 acceptance gate:
|
| 4 |
+
- MMLUFormatReward: correct→+1, wrong→0, unparseable→−0.2, multiple→−0.1,
|
| 5 |
+
length-penalty, and an "always C" option-prior exploit is DETECTABLE via the
|
| 6 |
+
logged option distribution. Rationale style is NOT scored.
|
| 7 |
+
- dual_kl_logger: KL(p‖p)==0 and KL grows as the policy moves.
|
| 8 |
+
- channel_ladder_configs: A1 both off, A2 SDPO-only, A3 replay-only.
|
| 9 |
+
- B4: the SDPO channel actually FIRES (NONZERO loss) with REAL collator-built
|
| 10 |
+
alignment indices. See the module docstring on test_b4_* for the honest
|
| 11 |
+
stub-vs-real note.
|
| 12 |
+
|
| 13 |
+
All CPU-only and fast (stub tokenizer + tiny model — no model download).
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import pytest
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from composer_replication.integrations.altered_minds import (
|
| 21 |
+
MMLUFormatReward,
|
| 22 |
+
channel_ladder_configs,
|
| 23 |
+
dual_kl_logger,
|
| 24 |
+
randomize_options,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ===========================================================================
|
| 29 |
+
# MMLUFormatReward
|
| 30 |
+
# ===========================================================================
|
| 31 |
+
|
| 32 |
+
def test_reward_correct_wrong_unparseable_multiple():
|
| 33 |
+
r = MMLUFormatReward()
|
| 34 |
+
completions = [
|
| 35 |
+
"Reasoning blah. Answer: B", # correct
|
| 36 |
+
"I think it's Answer: A", # wrong (gold C)
|
| 37 |
+
"no marker here at all", # unparseable
|
| 38 |
+
"Answer: A then actually Answer: D", # multiple distinct
|
| 39 |
+
'{"answer": "C"}', # JSON correct
|
| 40 |
+
]
|
| 41 |
+
answers = ["B", "C", "B", "A", "C"]
|
| 42 |
+
out = r(prompts=None, completions=completions, answers=answers)
|
| 43 |
+
assert out[0] == pytest.approx(1.0) # correct
|
| 44 |
+
assert out[1] == pytest.approx(0.0) # wrong
|
| 45 |
+
assert out[2] == pytest.approx(-0.2) # unparseable
|
| 46 |
+
assert out[3] == pytest.approx(-0.1) # multiple distinct markers
|
| 47 |
+
assert out[4] == pytest.approx(1.0) # JSON correct
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_reward_last_match_wins_same_letter_not_penalized():
|
| 51 |
+
"""Two markers of the SAME letter is not 'multiple distinct' — last wins."""
|
| 52 |
+
r = MMLUFormatReward()
|
| 53 |
+
out = r(completions=["Answer: C ... so my final Answer: C"], answers=["C"])
|
| 54 |
+
assert out[0] == pytest.approx(1.0)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def test_reward_case_insensitive_and_json_variants():
|
| 58 |
+
r = MMLUFormatReward()
|
| 59 |
+
out = r(
|
| 60 |
+
completions=["answer: d", '{"answer":"a"}'],
|
| 61 |
+
answers=["D", "A"],
|
| 62 |
+
)
|
| 63 |
+
assert out[0] == pytest.approx(1.0)
|
| 64 |
+
assert out[1] == pytest.approx(1.0)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def test_reward_length_penalty_only_past_cap():
|
| 68 |
+
"""A correct-but-long completion is penalized by ~0.001/char past the cap;
|
| 69 |
+
a short one is not. Rationale CONTENT is never scored — only length."""
|
| 70 |
+
r = MMLUFormatReward(rationale_char_cap=20, length_penalty_per_char=0.001)
|
| 71 |
+
short = "Answer: B" # under cap
|
| 72 |
+
long = "x" * 120 + " Answer: B" # ~130 chars, 110 over cap
|
| 73 |
+
out = r(completions=[short, long], answers=["B", "B"])
|
| 74 |
+
assert out[0] == pytest.approx(1.0)
|
| 75 |
+
# 130 - 20 = 110 over => penalty 0.110; reward 1.0 - 0.110
|
| 76 |
+
assert out[1] < 1.0
|
| 77 |
+
assert out[1] == pytest.approx(1.0 - 0.001 * (len(long) - 20))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def test_reward_always_C_exploit_is_detectable():
|
| 81 |
+
"""An 'always C' policy that happens to be right when gold==C scores well on
|
| 82 |
+
those items, but the logged option distribution reveals the exploit."""
|
| 83 |
+
r = MMLUFormatReward()
|
| 84 |
+
completions = [f"Answer: C" for _ in range(10)]
|
| 85 |
+
golds = ["C", "A", "B", "C", "D", "A", "C", "B", "C", "D"]
|
| 86 |
+
r(completions=completions, answers=golds)
|
| 87 |
+
report = r.exploit_report()
|
| 88 |
+
assert report["most_common"] == "C"
|
| 89 |
+
# Every parsed answer was C => fraction 1.0 — the exploit signature.
|
| 90 |
+
assert report["max_fraction"] == pytest.approx(1.0)
|
| 91 |
+
assert report["counts"] == {"C": 10}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def test_reward_requires_answers():
|
| 95 |
+
r = MMLUFormatReward()
|
| 96 |
+
with pytest.raises(ValueError, match="requires `answers`"):
|
| 97 |
+
r(completions=["Answer: A"])
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def test_randomize_options_tracks_label_remap_and_updates_gold():
|
| 101 |
+
item = {"question": "q", "options": ["w", "x", "y", "z"], "answer": "A"}
|
| 102 |
+
shuffled, remap = randomize_options(item, seed=7)
|
| 103 |
+
# All four letters map to four distinct new letters (a permutation).
|
| 104 |
+
assert sorted(remap.keys()) == ["A", "B", "C", "D"]
|
| 105 |
+
assert sorted(remap.values()) == ["A", "B", "C", "D"]
|
| 106 |
+
# The gold option's text ("w", originally A) now lives at its remapped letter.
|
| 107 |
+
new_gold_letter = shuffled["answer"]
|
| 108 |
+
new_gold_idx = ord(new_gold_letter) - ord("A")
|
| 109 |
+
assert shuffled["options"][new_gold_idx] == "w"
|
| 110 |
+
assert remap["A"] == new_gold_letter
|
| 111 |
+
# Determinism.
|
| 112 |
+
shuffled2, remap2 = randomize_options(item, seed=7)
|
| 113 |
+
assert remap == remap2 and shuffled["options"] == shuffled2["options"]
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ===========================================================================
|
| 117 |
+
# dual_kl_logger
|
| 118 |
+
# ===========================================================================
|
| 119 |
+
|
| 120 |
+
def test_dual_kl_self_is_zero():
|
| 121 |
+
"""KL(p‖p) == 0 for both diagnostics."""
|
| 122 |
+
logits = torch.randn(2, 5, 16)
|
| 123 |
+
out = dual_kl_logger(logits, logits, logits)
|
| 124 |
+
assert out["kl_to_altered_init"] == pytest.approx(0.0, abs=1e-6)
|
| 125 |
+
assert out["kl_to_base"] == pytest.approx(0.0, abs=1e-6)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def test_dual_kl_grows_as_policy_moves():
|
| 129 |
+
"""As the policy distribution moves further from a fixed reference, the KL
|
| 130 |
+
grows monotonically. Both diagnostics are non-negative."""
|
| 131 |
+
torch.manual_seed(0)
|
| 132 |
+
ref = torch.randn(1, 4, 16)
|
| 133 |
+
base = torch.randn(1, 4, 16)
|
| 134 |
+
|
| 135 |
+
near = ref + 0.1 * torch.randn_like(ref)
|
| 136 |
+
far = ref + 2.0 * torch.randn_like(ref)
|
| 137 |
+
|
| 138 |
+
kl_near = dual_kl_logger(near, ref, base)["kl_to_altered_init"]
|
| 139 |
+
kl_far = dual_kl_logger(far, ref, base)["kl_to_altered_init"]
|
| 140 |
+
|
| 141 |
+
assert kl_near >= -1e-9
|
| 142 |
+
assert kl_far > kl_near, f"KL should grow as policy moves: {kl_near} -> {kl_far}"
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def test_dual_kl_mask_restricts_tokens():
|
| 146 |
+
"""A token mask restricts the mean to the masked answer+reasoning tokens."""
|
| 147 |
+
torch.manual_seed(1)
|
| 148 |
+
policy = torch.randn(1, 4, 8)
|
| 149 |
+
ref = torch.randn(1, 4, 8)
|
| 150 |
+
base = torch.randn(1, 4, 8)
|
| 151 |
+
mask = torch.tensor([[1, 1, 0, 0]])
|
| 152 |
+
out = dual_kl_logger(policy, ref, base, mask=mask)
|
| 153 |
+
# Masked-all-zero => 0.0 (guarded), nonzero mask => finite non-negative.
|
| 154 |
+
assert out["kl_to_altered_init"] >= -1e-9
|
| 155 |
+
zero = dual_kl_logger(policy, ref, base, mask=torch.zeros(1, 4))
|
| 156 |
+
assert zero["kl_to_altered_init"] == 0.0
|
| 157 |
+
assert zero["kl_to_base"] == 0.0
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# ===========================================================================
|
| 161 |
+
# channel_ladder_configs
|
| 162 |
+
# ===========================================================================
|
| 163 |
+
|
| 164 |
+
def test_ladder_arms_and_order():
|
| 165 |
+
arms = channel_ladder_configs()
|
| 166 |
+
assert [a["arm"] for a in arms] == ["A0", "A1", "A2", "A3", "A4"]
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def test_ladder_a0_is_no_rl_sentinel():
|
| 170 |
+
a0 = channel_ladder_configs()[0]
|
| 171 |
+
assert a0["arm"] == "A0"
|
| 172 |
+
assert a0["alpha_sdpo"] is None
|
| 173 |
+
assert a0["beta_replay"] is None
|
| 174 |
+
assert a0["kl_beta"] is None
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def test_ladder_a1_both_off():
|
| 178 |
+
a1 = channel_ladder_configs()[1]
|
| 179 |
+
assert a1["alpha_sdpo"] == 0.0
|
| 180 |
+
assert a1["beta_replay"] == 0.0
|
| 181 |
+
assert a1["kl_beta"] == 0.02
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def test_ladder_a2_sdpo_only():
|
| 185 |
+
a2 = channel_ladder_configs()[2]
|
| 186 |
+
assert a2["alpha_sdpo"] == 0.02
|
| 187 |
+
assert a2["beta_replay"] == 0.0
|
| 188 |
+
assert a2["kl_beta"] == 0.02
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def test_ladder_a3_replay_only():
|
| 192 |
+
a3 = channel_ladder_configs()[3]
|
| 193 |
+
assert a3["alpha_sdpo"] == 0.0
|
| 194 |
+
assert a3["beta_replay"] == 0.05
|
| 195 |
+
assert a3["kl_beta"] == 0.02
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def test_ladder_a4_combined():
|
| 199 |
+
a4 = channel_ladder_configs()[4]
|
| 200 |
+
assert a4["alpha_sdpo"] == 0.02
|
| 201 |
+
assert a4["beta_replay"] == 0.05
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# ===========================================================================
|
| 205 |
+
# B4 — the SDPO channel actually FIRES (NONZERO) with REAL collator indices
|
| 206 |
+
# ===========================================================================
|
| 207 |
+
#
|
| 208 |
+
# HONEST NOTE ON STUB-VS-REAL (ADR-013 B4 acceptance):
|
| 209 |
+
#
|
| 210 |
+
# This proof uses the same TinyLM stub pattern as
|
| 211 |
+
# trainer/tests/test_sdpo_alignment_indices.py, NOT a real Qwen checkpoint
|
| 212 |
+
# (kept offline/CPU and deterministic). The alignment indices are REAL: they are
|
| 213 |
+
# built by the production ComposerDataCollator from a trace that HAS an error
|
| 214 |
+
# turn (so ctx_teacher_input_ids + student/teacher_response_idx are genuinely
|
| 215 |
+
# emitted by the shipped collator, exactly as in a real run).
|
| 216 |
+
#
|
| 217 |
+
# Why we must perturb the student tokens to get a NONZERO loss: the collator's
|
| 218 |
+
# placeholder-alignment trick makes student and teacher carry the SAME token ids
|
| 219 |
+
# at the SAME absolute positions at valid aligned indices, so a deterministic
|
| 220 |
+
# stub yields JSD≈0 there (the CORRECT answer for a perfectly-aligned identical
|
| 221 |
+
# model — see that test's gate-3 note). To prove the channel genuinely GATHERS
|
| 222 |
+
# the aligned positions and computes nonzero divergence, we make the student's
|
| 223 |
+
# input_ids DIFFER from the teacher's at exactly the aligned response positions
|
| 224 |
+
# — this mimics the hint actually changing the recovery tokens (the real-world
|
| 225 |
+
# case where SDPO has a signal to distill). With a position-dependent stub,
|
| 226 |
+
# different aligned token ids => different logits => provably NONZERO JSD on a
|
| 227 |
+
# grad path, through the real collator-built indices.
|
| 228 |
+
|
| 229 |
+
from composer_replication.trainer.data_collator import ( # noqa: E402
|
| 230 |
+
CollatorConfig,
|
| 231 |
+
ComposerDataCollator,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class _StubTok:
|
| 236 |
+
"""Word-level deterministic tokenizer; apply_chat_template space-joins."""
|
| 237 |
+
|
| 238 |
+
pad_token_id = 0
|
| 239 |
+
|
| 240 |
+
def __init__(self) -> None:
|
| 241 |
+
self._v: dict[str, int] = {"<pad>": 0, "<bos>": 1, "<eos>": 2}
|
| 242 |
+
|
| 243 |
+
def _id(self, w: str) -> int:
|
| 244 |
+
if w not in self._v:
|
| 245 |
+
self._v[w] = len(self._v)
|
| 246 |
+
return self._v[w]
|
| 247 |
+
|
| 248 |
+
def __call__(self, text, **_k):
|
| 249 |
+
return {"input_ids": [self._id(w) for w in text.split()] if text else []}
|
| 250 |
+
|
| 251 |
+
def apply_chat_template(self, messages, tokenize=True, **_k): # noqa: ARG002
|
| 252 |
+
return [self._id(w) for w in " ".join(m.get("content", "") for m in messages).split()]
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class _TinyLM(torch.nn.Module):
|
| 256 |
+
"""Position-dependent minimal model: model(input_ids=...).logits."""
|
| 257 |
+
|
| 258 |
+
def __init__(self, vocab: int = 64, hidden: int = 8, max_pos: int = 512):
|
| 259 |
+
super().__init__()
|
| 260 |
+
torch.manual_seed(0)
|
| 261 |
+
self.embed = torch.nn.Embedding(vocab, hidden)
|
| 262 |
+
self.pos = torch.nn.Embedding(max_pos, hidden)
|
| 263 |
+
self.head = torch.nn.Linear(hidden, vocab)
|
| 264 |
+
|
| 265 |
+
def forward(self, input_ids: torch.Tensor):
|
| 266 |
+
T = input_ids.size(1)
|
| 267 |
+
positions = torch.arange(T, device=input_ids.device).unsqueeze(0)
|
| 268 |
+
h = self.embed(input_ids) + self.pos(positions)
|
| 269 |
+
|
| 270 |
+
class _Out:
|
| 271 |
+
pass
|
| 272 |
+
|
| 273 |
+
out = _Out()
|
| 274 |
+
out.logits = self.head(h)
|
| 275 |
+
return out
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def _hint_gen(_kind, _meta):
|
| 279 |
+
return "HINT search before reading"
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def _error_trace(trace_id: str, recovery: str = "let me use a real tool instead now"):
|
| 283 |
+
return {
|
| 284 |
+
"trace_id": trace_id,
|
| 285 |
+
"turns": [
|
| 286 |
+
{"role": "user", "content": "do the task now"},
|
| 287 |
+
{"role": "user", "content": "tool not found error occurred"},
|
| 288 |
+
{
|
| 289 |
+
"role": "assistant",
|
| 290 |
+
"content": recovery,
|
| 291 |
+
"tool_error": "tool_not_found",
|
| 292 |
+
"error_meta": {},
|
| 293 |
+
},
|
| 294 |
+
],
|
| 295 |
+
"final_reward": 0.0,
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def _make_sdpo_trainer(alpha_sdpo: float):
|
| 300 |
+
from composer_replication.trainer.composer_trainer import ComposerReplicationTrainer
|
| 301 |
+
|
| 302 |
+
obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
|
| 303 |
+
obj.alpha_sdpo = alpha_sdpo
|
| 304 |
+
obj.sdpo_jsd_beta = 0.5
|
| 305 |
+
obj.sdpo_temperature = 1.0
|
| 306 |
+
obj.sdpo_token_clip = None
|
| 307 |
+
obj.strict_sdpo_alignment = True # production default
|
| 308 |
+
return obj
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def test_b4_sdpo_fires_nonzero_with_real_collator_indices():
|
| 312 |
+
"""B4: with REAL collator-built alignment indices and the student tokens
|
| 313 |
+
differing from the teacher at the aligned response positions (hint changed
|
| 314 |
+
the recovery tokens), the SDPO channel gathers those positions and produces
|
| 315 |
+
a NONZERO JSD on a grad path — proving the channel actually FIRES."""
|
| 316 |
+
tok = _StubTok()
|
| 317 |
+
cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False)
|
| 318 |
+
collator = ComposerDataCollator(tokenizer=tok, config=cfg)
|
| 319 |
+
batch = collator([_error_trace("b4-fires")])
|
| 320 |
+
|
| 321 |
+
# Sanity: the collator genuinely emitted error-site teacher context + indices.
|
| 322 |
+
assert batch["ctx_teacher_input_ids"].numel() > 0
|
| 323 |
+
s_idx = batch["student_response_idx"]
|
| 324 |
+
t_idx = batch["teacher_response_idx"]
|
| 325 |
+
s_valid = batch["student_response_valid"]
|
| 326 |
+
assert int(s_valid.sum()) > 0, "no valid aligned positions — collator emitted nothing"
|
| 327 |
+
|
| 328 |
+
# Perturb the STUDENT tokens at the aligned response positions so they differ
|
| 329 |
+
# from the teacher's tokens there (the hint changed the recovery tokens). We
|
| 330 |
+
# keep the REAL collator-built indices; only the student input_ids change.
|
| 331 |
+
student_ids = batch["input_ids"].clone()
|
| 332 |
+
vocab_ceiling = int(
|
| 333 |
+
max(batch["input_ids"].max(), batch["ctx_teacher_input_ids"].max())
|
| 334 |
+
) + 8
|
| 335 |
+
for b in range(s_idx.shape[0]):
|
| 336 |
+
for k in range(s_idx.shape[1]):
|
| 337 |
+
if bool(s_valid[b, k]):
|
| 338 |
+
pos = int(s_idx[b, k])
|
| 339 |
+
# bump to a different, in-vocab token id (deterministic).
|
| 340 |
+
student_ids[b, pos] = (int(student_ids[b, pos]) + 3) % vocab_ceiling
|
| 341 |
+
batch["input_ids"] = student_ids
|
| 342 |
+
|
| 343 |
+
model = _TinyLM(vocab=max(vocab_ceiling, 8))
|
| 344 |
+
obj = _make_sdpo_trainer(alpha_sdpo=0.02) # A2 config (SDPO-only small)
|
| 345 |
+
|
| 346 |
+
loss = obj._compute_sdpo_loss(model, batch)
|
| 347 |
+
val = float(loss.detach())
|
| 348 |
+
|
| 349 |
+
assert val == val and val not in (float("inf"), float("-inf")), "loss not finite"
|
| 350 |
+
assert loss.requires_grad, "SDPO loss must be on a grad path"
|
| 351 |
+
assert val > 1e-6, (
|
| 352 |
+
f"SDPO channel did not fire: JSD={val} (expected NONZERO once the "
|
| 353 |
+
"aligned student/teacher tokens differ). The channel must gather the "
|
| 354 |
+
"real collator indices and compute a positive divergence."
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
# Prove it is differentiable end-to-end: backward populates a real gradient.
|
| 358 |
+
(obj.alpha_sdpo * loss).backward()
|
| 359 |
+
grad_norm = sum(
|
| 360 |
+
float(p.grad.norm()) for p in model.parameters() if p.grad is not None
|
| 361 |
+
)
|
| 362 |
+
assert grad_norm > 0.0, "no gradient flowed from the SDPO loss into the model"
|
|
@@ -79,16 +79,43 @@ Fits inside the user's existing $400 altered-minds budget.
|
|
| 79 |
|
| 80 |
### Phase 3 — GRPO with the framework
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
### Phase 4 — re-evaluate
|
| 94 |
|
|
|
|
| 79 |
|
| 80 |
### Phase 3 — GRPO with the framework
|
| 81 |
|
| 82 |
+
> **⚠️ SUPERSEDED by [ADR-013](adrs/ADR-013-lma-integration-channel-ladder.md).**
|
| 83 |
+
> The original all-channels-on combined recipe (α=0.2, β=0.4) is **not used**.
|
| 84 |
+
> A cross-family research critique (2026-05-29) found a combined-first run
|
| 85 |
+
> **scientifically uninterpretable**: it confounds four effects (task RL,
|
| 86 |
+
> self-distillation of altered reasoning, frontier-teacher imitation, KL
|
| 87 |
+
> anchoring), so any observed change in the alteration signature cannot be
|
| 88 |
+
> attributed to a channel. Worse, **SDPO against the altered model's own
|
| 89 |
+
> hint-conditioned forward pass is the channel most likely to AMPLIFY the
|
| 90 |
+
> distortion** (teacher == student-family; if hints add no independent
|
| 91 |
+
> information, the optimum is to imitate the altered conditional distribution,
|
| 92 |
+
> sharpening a soft bias into a hard preference). SDPO here is therefore an
|
| 93 |
+
> *experimental intervention*, not a benign stabilizer.
|
| 94 |
+
|
| 95 |
+
**Use the isolated-channel ladder (ADR-013) instead** — sweep arms A0–A4 with
|
| 96 |
+
identical seeds/prompts so each channel's effect is attributable:
|
| 97 |
+
|
| 98 |
+
| Arm | alpha_sdpo | beta_replay | Purpose |
|
| 99 |
+
|---|---|---|---|
|
| 100 |
+
| A0 | — | — | altered SFT, no RL (control) |
|
| 101 |
+
| A1 | 0.0 | 0.0 | GRPO-only baseline |
|
| 102 |
+
| A2 | **0.02** | 0.0 | +SDPO small (amplification probe) |
|
| 103 |
+
| A3 | 0.0 | **0.05** | +replay-DPO small (washout probe) |
|
| 104 |
+
| A4 | 0.02 | 0.05 | combined — only after A1–A3 interpretable |
|
| 105 |
+
|
| 106 |
+
`kl_beta=0.02` (KL-to-altered-init) on every RL arm, adaptive to 0.01–0.03
|
| 107 |
+
nats/token; hard-stop/LR-cut if KL > ~0.08. The framework provides the ladder
|
| 108 |
+
via `composer_replication.integrations.altered_minds.channel_ladder_configs()`,
|
| 109 |
+
the structured `MMLUFormatReward` (scores the final answer letter + format
|
| 110 |
+
only — never rationale style, so distorted-but-persuasive reasoning is not
|
| 111 |
+
rewarded), and `dual_kl_logger` (logs KL-to-altered-init **and** KL-to-base each
|
| 112 |
+
step — the washout-vs-amplification instrument).
|
| 113 |
+
|
| 114 |
+
Train for ~500 steps per arm on a single GPU (Qwen-0.5B feasibility-test
|
| 115 |
+
already confirmed; for Llama-8B, use Modal + the framework's `ServerlessExecutor`
|
| 116 |
+
per ADR-005 — local 5090 is too small). The real 8B/LMA-checkpoint run remains
|
| 117 |
+
**user-gated** (it spends grant budget) — ADR-013 ships the capability, proven
|
| 118 |
+
CPU-only on a small model (`examples/altered_minds_channel_ladder/`).
|
| 119 |
|
| 120 |
### Phase 4 — re-evaluate
|
| 121 |
|
|
@@ -1,5 +1,12 @@
|
|
| 1 |
# Vision Validation: Does the Framework Encapsulate the Original Brief?
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
> **Status:** Self-audit, 2026-05-25 (Wave 6).
|
| 4 |
> **Question:** Does what we've built reflect what was originally asked for, or did we drift?
|
| 5 |
> **Method:** Recover original brief verbatim → atomic-clause decomposition → traceability matrix → adversarial self-review → user-journey simulation → concrete pass/fail scorecard with gap-closing actions.
|
|
|
|
| 1 |
# Vision Validation: Does the Framework Encapsulate the Original Brief?
|
| 2 |
|
| 3 |
+
> **## Status as of 2026-05-29**
|
| 4 |
+
> The framework is past-skeleton: 8 subpackages (`composer_replication/*`), 210 passing tests, and operational end-to-end examples (`gsm8k_grpo`, `sdpo_with_real_traces_production`). The 3-channel loss, layered hint-generation, trace-ingestion, and DiLoCo have all shipped and been cross-family reviewed.
|
| 5 |
+
>
|
| 6 |
+
> **Two remaining honest gaps:**
|
| 7 |
+
> 1. Docker/TorchForge substrate E2E is hardware-blocked (lacking local multi-GPU rig for the orchestrator layer).
|
| 8 |
+
> 2. Real LMA full-scale run (8B model, 10k SWE-bench traces) is user-budget-gated.
|
| 9 |
+
|
| 10 |
> **Status:** Self-audit, 2026-05-25 (Wave 6).
|
| 11 |
> **Question:** Does what we've built reflect what was originally asked for, or did we drift?
|
| 12 |
> **Method:** Recover original brief verbatim → atomic-clause decomposition → traceability matrix → adversarial self-review → user-journey simulation → concrete pass/fail scorecard with gap-closing actions.
|
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# altered_minds_channel_ladder — B4 SDPO-fires proof (ADR-013)
|
| 2 |
+
|
| 3 |
+
CPU end-to-end proof that the **SDPO channel actually FIRES (nonzero)** on a
|
| 4 |
+
batch built by the production `ComposerDataCollator` with **real, collator-built
|
| 5 |
+
alignment indices**, in the **A2** isolated-channel-ladder config
|
| 6 |
+
(`alpha_sdpo=0.02`).
|
| 7 |
+
|
| 8 |
+
## Why this exists
|
| 9 |
+
|
| 10 |
+
`examples/composer_grpo_sdpo_smoke` proves the SDPO channel is *wired* into a
|
| 11 |
+
live TRL Dr.GRPO loop, but its toy synthetic rollouts carry no error sites, so
|
| 12 |
+
`_compute_sdpo_loss` returns `0` — the channel never actually fires. This script
|
| 13 |
+
closes that gap: it feeds a trace that **has an error turn**, so the collator
|
| 14 |
+
emits `ctx_teacher_input_ids` + `student_response_idx`/`teacher_response_idx`,
|
| 15 |
+
and the SDPO JSD is proven **nonzero** on a differentiable grad path.
|
| 16 |
+
|
| 17 |
+
## Proof achieved: `TinyLM-stub-with-differing-tokens`
|
| 18 |
+
|
| 19 |
+
- **Alignment indices: REAL** — emitted by the shipped `ComposerDataCollator`
|
| 20 |
+
from a genuine error-turn trace, exactly as in a real run.
|
| 21 |
+
- **Model: a deterministic position-dependent `TinyLM` stub** (CPU, no
|
| 22 |
+
download), the same pattern as
|
| 23 |
+
`composer_replication/trainer/tests/test_sdpo_alignment_indices.py`.
|
| 24 |
+
- **Why student tokens are perturbed:** the collator's placeholder-alignment
|
| 25 |
+
trick makes student and teacher carry identical tokens at identical positions
|
| 26 |
+
at the valid aligned indices, so a deterministic stub yields `JSD≈0` there
|
| 27 |
+
(the *correct* answer for a perfectly-aligned identical model). To prove the
|
| 28 |
+
channel genuinely **gathers** the aligned positions and computes a real
|
| 29 |
+
divergence, the student's `input_ids` are made to **differ** from the
|
| 30 |
+
teacher's at exactly those aligned positions — mimicking the hint actually
|
| 31 |
+
changing the recovery tokens (the real-world case where SDPO has signal to
|
| 32 |
+
distill). Different aligned tokens ⇒ different logits ⇒ provably **NONZERO**
|
| 33 |
+
JSD.
|
| 34 |
+
|
| 35 |
+
This is the honest, deterministic CPU proof. Loading a real Qwen2.5-0.5B
|
| 36 |
+
checkpoint is **not required** for the B4 gate and is **not** the same as loading
|
| 37 |
+
an LMA checkpoint (still user-gated, ADR-013 out-of-scope).
|
| 38 |
+
|
| 39 |
+
## Run
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
cd <repo> && .venv/bin/python examples/altered_minds_channel_ladder/run.py
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
Optional: `ALTERED_MINDS_REAL_MODEL=1` swaps the stub for a cached
|
| 46 |
+
Qwen2.5-0.5B-Instruct (offline, much slower on CPU). The same token-perturbation
|
| 47 |
+
is still required for a nonzero signal.
|
| 48 |
+
|
| 49 |
+
Exit `0` = PASS (SDPO fired nonzero), `1` = FAIL, `2` = SKIP (deps unavailable).
|
| 50 |
+
|
| 51 |
+
The automated assertion lives in
|
| 52 |
+
`composer_replication/integrations/altered_minds/tests/test_channel_ladder.py::test_b4_sdpo_fires_nonzero_with_real_collator_indices`.
|
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""B4 end-to-end CPU proof: the SDPO channel actually FIRES (NONZERO) on a real
|
| 2 |
+
collator-built batch with genuine alignment indices (ADR-013).
|
| 3 |
+
|
| 4 |
+
The existing examples/composer_grpo_sdpo_smoke proves the SDPO channel is *wired*
|
| 5 |
+
into a live TRL Dr.GRPO loop, but its toy synthetic rollouts carry no error
|
| 6 |
+
sites, so _compute_sdpo_loss returns 0 (the channel never actually fires). This
|
| 7 |
+
script closes that gap: it builds a REAL ComposerDataCollator batch from a trace
|
| 8 |
+
that HAS an error turn — so ctx_teacher_input_ids + student/teacher_response_idx
|
| 9 |
+
are emitted by the shipped collator — and proves the SDPO JSD is NONZERO over
|
| 10 |
+
>=1 step, in the A2 ladder config (alpha_sdpo=0.02).
|
| 11 |
+
|
| 12 |
+
PROOF ACHIEVED: stub-with-differing-tokens (NOT a real Qwen checkpoint).
|
| 13 |
+
|
| 14 |
+
- Alignment indices: REAL (production ComposerDataCollator, real error turn).
|
| 15 |
+
- Model: a deterministic position-dependent TinyLM stub (CPU, no download),
|
| 16 |
+
the same pattern used by trainer/tests/test_sdpo_alignment_indices.py.
|
| 17 |
+
- Why perturb student tokens: the collator's placeholder-alignment trick makes
|
| 18 |
+
student & teacher carry identical tokens at identical positions at the valid
|
| 19 |
+
aligned indices, so a deterministic stub yields JSD≈0 there (correct for a
|
| 20 |
+
perfectly-aligned identical model). To prove the channel GATHERS the aligned
|
| 21 |
+
positions and computes a real divergence, the student's input_ids are made to
|
| 22 |
+
DIFFER from the teacher's at exactly those aligned positions — mimicking the
|
| 23 |
+
hint actually changing the recovery tokens (the real-world case where SDPO
|
| 24 |
+
has signal to distill). Different aligned tokens => different logits =>
|
| 25 |
+
provably NONZERO JSD, on a differentiable grad path.
|
| 26 |
+
|
| 27 |
+
To run the SAME assertion against a real Qwen2.5-0.5B-Instruct (if cached
|
| 28 |
+
offline), set ALTERED_MINDS_REAL_MODEL=1 — note that even with a real model the
|
| 29 |
+
NONZERO signal still requires the aligned student/teacher tokens to differ, so
|
| 30 |
+
this script keeps the same token-perturbation; the real-model path only swaps
|
| 31 |
+
the stub for the HF model and is much slower on CPU.
|
| 32 |
+
|
| 33 |
+
Exit 0 = PASS (SDPO fired nonzero), 1 = FAIL, 2 = SKIP (deps unavailable).
|
| 34 |
+
"""
|
| 35 |
+
from __future__ import annotations
|
| 36 |
+
|
| 37 |
+
import os
|
| 38 |
+
import sys
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _build_tiny_lm(vocab: int):
|
| 42 |
+
import torch
|
| 43 |
+
|
| 44 |
+
class _TinyLM(torch.nn.Module):
|
| 45 |
+
def __init__(self, vocab: int = 64, hidden: int = 8, max_pos: int = 512):
|
| 46 |
+
super().__init__()
|
| 47 |
+
torch.manual_seed(0)
|
| 48 |
+
self.embed = torch.nn.Embedding(vocab, hidden)
|
| 49 |
+
self.pos = torch.nn.Embedding(max_pos, hidden)
|
| 50 |
+
self.head = torch.nn.Linear(hidden, vocab)
|
| 51 |
+
|
| 52 |
+
def forward(self, input_ids):
|
| 53 |
+
T = input_ids.size(1)
|
| 54 |
+
positions = torch.arange(T, device=input_ids.device).unsqueeze(0)
|
| 55 |
+
h = self.embed(input_ids) + self.pos(positions)
|
| 56 |
+
|
| 57 |
+
class _Out:
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
out = _Out()
|
| 61 |
+
out.logits = self.head(h)
|
| 62 |
+
return out
|
| 63 |
+
|
| 64 |
+
return _TinyLM(vocab=max(vocab, 8))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class _StubTok:
|
| 68 |
+
pad_token_id = 0
|
| 69 |
+
|
| 70 |
+
def __init__(self) -> None:
|
| 71 |
+
self._v = {"<pad>": 0, "<bos>": 1, "<eos>": 2}
|
| 72 |
+
|
| 73 |
+
def _id(self, w: str) -> int:
|
| 74 |
+
if w not in self._v:
|
| 75 |
+
self._v[w] = len(self._v)
|
| 76 |
+
return self._v[w]
|
| 77 |
+
|
| 78 |
+
def __call__(self, text, **_k):
|
| 79 |
+
return {"input_ids": [self._id(w) for w in text.split()] if text else []}
|
| 80 |
+
|
| 81 |
+
def apply_chat_template(self, messages, tokenize=True, **_k): # noqa: ARG002
|
| 82 |
+
return [
|
| 83 |
+
self._id(w)
|
| 84 |
+
for w in " ".join(m.get("content", "") for m in messages).split()
|
| 85 |
+
]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _hint_gen(_kind, _meta):
|
| 89 |
+
return "HINT search before reading"
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _error_trace():
|
| 93 |
+
return {
|
| 94 |
+
"trace_id": "b4-channel-ladder",
|
| 95 |
+
"turns": [
|
| 96 |
+
{"role": "user", "content": "do the task now"},
|
| 97 |
+
{"role": "user", "content": "tool not found error occurred"},
|
| 98 |
+
{
|
| 99 |
+
"role": "assistant",
|
| 100 |
+
"content": "let me use a real working tool instead now",
|
| 101 |
+
"tool_error": "tool_not_found",
|
| 102 |
+
"error_meta": {},
|
| 103 |
+
},
|
| 104 |
+
],
|
| 105 |
+
"final_reward": 0.0,
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def main() -> int:
|
| 110 |
+
os.environ.setdefault("HF_HUB_OFFLINE", "1")
|
| 111 |
+
os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
|
| 112 |
+
|
| 113 |
+
try:
|
| 114 |
+
import torch # noqa: F401
|
| 115 |
+
|
| 116 |
+
from composer_replication.integrations.altered_minds import (
|
| 117 |
+
channel_ladder_configs,
|
| 118 |
+
)
|
| 119 |
+
from composer_replication.trainer.composer_trainer import (
|
| 120 |
+
ComposerReplicationTrainer,
|
| 121 |
+
make_dr_grpo_config,
|
| 122 |
+
)
|
| 123 |
+
from composer_replication.trainer.data_collator import (
|
| 124 |
+
CollatorConfig,
|
| 125 |
+
ComposerDataCollator,
|
| 126 |
+
)
|
| 127 |
+
except Exception as e: # noqa: BLE001
|
| 128 |
+
print(f"SKIP: import failed: {e!r}")
|
| 129 |
+
return 2
|
| 130 |
+
|
| 131 |
+
# A2 arm = +SDPO small (alpha_sdpo=0.02), the amplification probe.
|
| 132 |
+
a2 = next(a for a in channel_ladder_configs() if a["arm"] == "A2")
|
| 133 |
+
print(f"[b4] ladder arm A2: alpha_sdpo={a2['alpha_sdpo']} "
|
| 134 |
+
f"beta_replay={a2['beta_replay']} kl_beta={a2['kl_beta']}")
|
| 135 |
+
|
| 136 |
+
# make_dr_grpo_config is exercised to prove the config wiring is intact
|
| 137 |
+
# (the actual TLM stub forward does not need a GRPOConfig, but a real A2
|
| 138 |
+
# runner would pass this through to ComposerReplicationTrainer).
|
| 139 |
+
try:
|
| 140 |
+
cfg = make_dr_grpo_config(output_dir="/tmp/b4_ladder_out", report_to=[])
|
| 141 |
+
print(f"[b4] Dr.GRPO config OK: loss_type={cfg.loss_type} "
|
| 142 |
+
f"scale_rewards={cfg.scale_rewards} num_iterations={cfg.num_iterations}")
|
| 143 |
+
except Exception as e: # noqa: BLE001
|
| 144 |
+
print(f"[b4] (config build skipped: {e!r})")
|
| 145 |
+
|
| 146 |
+
# --- REAL collator-built batch with a genuine error turn ---
|
| 147 |
+
tok = _StubTok()
|
| 148 |
+
collator = ComposerDataCollator(
|
| 149 |
+
tokenizer=tok,
|
| 150 |
+
config=CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False),
|
| 151 |
+
)
|
| 152 |
+
batch = collator([_error_trace()])
|
| 153 |
+
|
| 154 |
+
if batch.get("ctx_teacher_input_ids") is None or batch["ctx_teacher_input_ids"].numel() == 0:
|
| 155 |
+
print("FAIL: collator emitted no error-site teacher context.")
|
| 156 |
+
return 1
|
| 157 |
+
s_idx = batch["student_response_idx"]
|
| 158 |
+
s_valid = batch["student_response_valid"]
|
| 159 |
+
if int(s_valid.sum()) == 0:
|
| 160 |
+
print("FAIL: no valid aligned response positions.")
|
| 161 |
+
return 1
|
| 162 |
+
print(f"[b4] collator emitted real alignment indices: "
|
| 163 |
+
f"student_response_idx shape={tuple(s_idx.shape)}, "
|
| 164 |
+
f"valid positions={int(s_valid.sum())}")
|
| 165 |
+
|
| 166 |
+
# --- Make the student tokens differ from teacher at aligned positions ---
|
| 167 |
+
student_ids = batch["input_ids"].clone()
|
| 168 |
+
vocab_ceiling = int(
|
| 169 |
+
max(batch["input_ids"].max(), batch["ctx_teacher_input_ids"].max())
|
| 170 |
+
) + 8
|
| 171 |
+
for b in range(s_idx.shape[0]):
|
| 172 |
+
for k in range(s_idx.shape[1]):
|
| 173 |
+
if bool(s_valid[b, k]):
|
| 174 |
+
pos = int(s_idx[b, k])
|
| 175 |
+
student_ids[b, pos] = (int(student_ids[b, pos]) + 3) % vocab_ceiling
|
| 176 |
+
batch["input_ids"] = student_ids
|
| 177 |
+
|
| 178 |
+
real_model = os.environ.get("ALTERED_MINDS_REAL_MODEL") == "1"
|
| 179 |
+
if real_model:
|
| 180 |
+
try:
|
| 181 |
+
from transformers import AutoModelForCausalLM
|
| 182 |
+
model_id = os.environ.get("SMOKE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
|
| 183 |
+
print(f"[b4] loading real model {model_id} (CPU, slow) ...")
|
| 184 |
+
model = AutoModelForCausalLM.from_pretrained(model_id)
|
| 185 |
+
print("[b4] real model loaded; proof path = REAL-MODEL")
|
| 186 |
+
except Exception as e: # noqa: BLE001
|
| 187 |
+
print(f"[b4] real model unavailable ({e!r}); falling back to TinyLM stub")
|
| 188 |
+
model = _build_tiny_lm(vocab_ceiling)
|
| 189 |
+
real_model = False
|
| 190 |
+
else:
|
| 191 |
+
model = _build_tiny_lm(vocab_ceiling)
|
| 192 |
+
|
| 193 |
+
# --- A2 config: SDPO-only small (alpha_sdpo=0.02), strict alignment ---
|
| 194 |
+
obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
|
| 195 |
+
obj.alpha_sdpo = float(a2["alpha_sdpo"])
|
| 196 |
+
obj.sdpo_jsd_beta = 0.5
|
| 197 |
+
obj.sdpo_temperature = 1.0
|
| 198 |
+
obj.sdpo_token_clip = None
|
| 199 |
+
obj.strict_sdpo_alignment = True
|
| 200 |
+
|
| 201 |
+
loss = obj._compute_sdpo_loss(model, batch)
|
| 202 |
+
val = float(loss.detach())
|
| 203 |
+
print("=" * 64)
|
| 204 |
+
print(f" proof path: {'REAL-MODEL' if real_model else 'TinyLM-stub-with-differing-tokens'}")
|
| 205 |
+
print(f" SDPO JSD (sdpo_kl): {val:.6f}")
|
| 206 |
+
print(f" requires_grad: {loss.requires_grad}")
|
| 207 |
+
|
| 208 |
+
if not (val == val) or val in (float("inf"), float("-inf")):
|
| 209 |
+
print(" RESULT: FAIL ❌ (loss not finite)")
|
| 210 |
+
return 1
|
| 211 |
+
if val <= 1e-6:
|
| 212 |
+
print(" RESULT: FAIL ❌ (SDPO channel did not fire — JSD ~0)")
|
| 213 |
+
return 1
|
| 214 |
+
|
| 215 |
+
(obj.alpha_sdpo * loss).backward()
|
| 216 |
+
grad_norm = sum(
|
| 217 |
+
float(p.grad.norm()) for p in model.parameters() if p.grad is not None
|
| 218 |
+
)
|
| 219 |
+
print(f" grad norm into model: {grad_norm:.6f}")
|
| 220 |
+
if grad_norm <= 0.0:
|
| 221 |
+
print(" RESULT: FAIL ❌ (no gradient flowed from SDPO loss)")
|
| 222 |
+
return 1
|
| 223 |
+
|
| 224 |
+
print(" RESULT: PASS ✅ (SDPO channel FIRED nonzero via real collator indices)")
|
| 225 |
+
return 0
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
if __name__ == "__main__":
|
| 229 |
+
sys.exit(main())
|