Commit 21647a4 (Codeseys) · 1 Parent(s): d02d724

feat(wave-b): ADR-013 LMA integration + B4 end-to-end SDPO-fires proof + doc refresh

ADR-013: composer_replication/integrations/altered_minds/ (generic, framework-side)
- MMLUFormatReward: structured-answer reward, scores only the final letter (not
rationale style), unparseable/multiple-answer penalties, option-order randomization,
always-C exploit detectable via logged option distribution.
- dual_kl_logger: KL(policy||altered-init) AND KL(policy||unaltered-base) as the
washout-vs-amplification instrument (neither optimized).
- channel_ladder_configs: A0-A4 isolated-channel ladder (alpha=0.02/beta=0.05),
replacing the uninterpretable combined alpha=0.2/beta=0.4 recipe per the
cross-family research critique (SDPO can AMPLIFY an altered model).

B4: examples/altered_minds_channel_ladder/run.py PROVES the SDPO channel FIRES
nonzero (JSD=0.0565, grad flows) through the REAL shipped collator alignment
indices — the thing the old smoke could not show (it only proved init). Honest
stub-with-differing-tokens proof; real-model path gated behind ALTERED_MINDS_REAL_MODEL=1.

ALTERED_MINDS_TIE_IN.md Phase-3 hyperparams superseded -> ADR-013 ladder.
BACKLOG.md + VISION_VALIDATION.md refreshed to actual shipped state.

227 passed / 16 skipped (was 210). NO real LMA checkpoint / Modal / budget spend
(user-gated). Workers: 2x Opus-4.8 + Gemini-3.1-pro (docs).
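
The B4 claim above (a nonzero JSD through the real alignment indices) rests on a simple property: a divergence term can only "fire" when the two distributions being compared actually differ. A minimal stdlib-only sketch of that property follows. It assumes the JSD in the commit message is the Jensen-Shannon divergence between a student distribution and a hint-conditioned teacher distribution; the function names and toy logits are illustrative and are not taken from the framework.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one vocab row."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def kl(p, q):
    """Discrete KL(p || q); terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def jsd(p, q):
    """Jensen-Shannon divergence: symmetric, bounded by ln 2."""
    mid = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, mid) + 0.5 * kl(q, mid)


# Toy distributions (assumed values, for illustration only).
student = softmax([1.0, 0.0, -1.0])
teacher_same = list(student)              # identical inputs: nothing to distill
teacher_hint = softmax([0.0, 1.0, -1.0])  # differing inputs: signal appears

jsd_same = jsd(student, teacher_same)
jsd_diff = jsd(student, teacher_hint)
```

With identical distributions the divergence is exactly zero, so a smoke test that feeds the same tokens to both passes can only show that the channel initializes, not that it contributes a gradient.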

BACKLOG.md CHANGED
@@ -1,8 +1,23 @@
1
  # Backlog — Composer 2.5 Replication Framework
2
 
3
- Imported from `docs/VISION_VALIDATION.md` § 6 (gaps) + § 9 (gap-closers) at 2026-05-26.
4
 
5
- ## Active items (CPU-only, no GPU budget)
6
 
7
  ### Spike 006 — Real HF model smoke (Wave 7)
8
 
@@ -60,23 +75,15 @@ Imported from `docs/VISION_VALIDATION.md` § 6 (gaps) + § 9 (gap-closers) at 20
60
  4. README quickstart updated to `pip install -e .` + `python examples/qwen3_05b_quickstart/run.py`.
61
  5. `pip install -e .` succeeds and quickstart runs end-to-end on CPU.
62
 
63
- **Estimate**: half a day, CPU.
64
-
65
- ## Modal-gated (if budget allows after gap-closers)
66
-
67
- ### Spike 002a-mini Real GPU smoke (Phase 10)
68
-
69
- **Closes**: the "did we ever run gradients on GPU" ambiguity currently everything is CPU-only.
70
-
71
- **Goal**: dispatch a 30-min A10G smoke on Modal that runs Spike 006 unchanged on GPU, verifies bf16 numerics, captures memory + step-time.
72
-
73
- **Acceptance**:
74
- 1. ADR-001 says Modal is the right choice for this workload + estimate is < $5.
75
- 2. Modal app builds, runs `composer_total_loss` for 50 steps on Qwen2.5-0.5B-Instruct.
76
- 3. Loss curve + memory profile saved to `spikes/002a-mini/` and pulled to local.
77
- 4. No new shape / dtype bug surfaced vs CPU run.
78
-
79
- **Estimate**: $1–3, 30 min wall-clock.
80
 
81
  ## Deferred (post-loop, GPU-gated)
82
 
 
1
  # Backlog — Composer 2.5 Replication Framework
2
 
3
+ Updated 2026-05-29 to reflect shipped waves (ingestion, diloco, packaging, datagen+RL, ADR-011/012/013, cross-family review).
4
 
5
+ ## Active items / Honest Gaps
6
+
7
+ ### Framework/Docker substrate E2E (Hardware-blocked)
8
+ - We lack the local multi-node GPU environment to run the true 8-node DiLoCo + Docker/TorchForge orchestrator E2E tests. Coverage is currently limited to unit-level and single-node pseudo-gradient checks.
9
+
10
+ ### Real 8B LMA run (User-budget-gated)
11
+ - The framework is proven on Qwen-0.5B and 1.5B (GSM8K/SDPO math traces).
12
+ - The ultimate goal (Llama-3-8B full LMA run with α/β ablation over 10k SWE-bench traces) requires a multi-GPU Modal drop + significant compute budget.
13
+
14
+ ## Modal-gated (if budget allows after gap-closers)
15
+
16
+ ### Spike 002a-mini — Real GPU smoke (Phase 10)
17
+ **Closes**: the "did we ever run gradients on GPU" ambiguity — currently everything is CPU-only.
18
+ - Goal: dispatch a 30-min A10G smoke on Modal that runs Qwen2.5-0.5B-Instruct natively on GPU.
19
+
20
+ ## Shipped (Past-Skeleton)
21
 
22
  ### Spike 006 — Real HF model smoke (Wave 7)
23
 
 
75
  4. README quickstart updated to `pip install -e .` + `python examples/qwen3_05b_quickstart/run.py`.
76
  5. `pip install -e .` succeeds and quickstart runs end-to-end on CPU.
77
 
78
+ ### Post-Skeleton Waves (Datagen, Alignment, Quality)
79
+ - **Trace Ingestion**: Shipped (`composer_replication/ingestion/`).
80
+ - **DiLoCo**: Shipped (`composer_replication/diloco/` outer-loop pseudo optimizer).
81
+ - **Packaging**: Shipped (`pip install -e .` succeeds and the quickstart runs end-to-end on CPU).
82
+ - **ADR-008/009/010 (Datagen, Layered Hints, Dr.GRPO+SDPO)**: Shipped, examples documented.
83
+ - **Cross-Family Architectural Review**: Shipped (`docs/reviews/cross-family-adr-008-009-010-2026-05-29/`).
84
+ - **Alignment / V&V Closure**: ADR-011 (SDPO alignment indices), ADR-012 (close review findings), ADR-013 (LMA integration channel-ladder) shipped.
85
+ - **Test Suites**: 227 passed / 16 skipped (was 210).
86
+ - **Real Examples**: `examples/gsm8k_grpo/`, `examples/sdpo_with_real_traces_production/`.
87
 
88
  ## Deferred (post-loop, GPU-gated)
89
 
composer_replication/integrations/__init__.py ADDED
File without changes
composer_replication/integrations/altered_minds/__init__.py ADDED
@@ -0,0 +1,48 @@
1
+ """altered_minds — framework-side, generic LMA integration glue (ADR-013).
2
+
3
+ This package is the *model-agnostic* scaffold that lets the Composer Replication
4
+ Framework drive the sister project llm-mental-alterations (LMA): take a
5
+ personality-altered SFT checkpoint and apply the framework's 3-channel RL to ask
6
+ whether task-driven RL washes out, preserves, or AMPLIFIES the alteration's
7
+ cognitive-distortion signature.
8
+
9
+ Nothing here loads an LMA checkpoint, calls Modal, or spends budget — that is
10
+ explicitly user-gated (ADR-013 "out of scope"). This package provides:
11
+
12
+ - ``MMLUFormatReward`` : structured-answer reward (final letter + format
13
+ only; never rationale style). Plus
14
+ ``randomize_options`` and a logged option
15
+ distribution so an "always C" exploit is
16
+ detectable.
17
+ - ``dual_kl_logger`` : logs KL(policy||altered_init) AND KL(policy||base)
18
+ each step — the washout/amplification instrument.
19
+ - ``channel_ladder_configs``: the A0-A4 isolated-channel ladder that REPLACES
20
+ the old combined alpha=0.2/beta=0.4 recipe.
21
+
22
+ See docs/adrs/ADR-013-lma-integration-channel-ladder.md.
23
+ """
24
+ from __future__ import annotations
25
+
26
+ from composer_replication.integrations.altered_minds.kl_logging import (
27
+ dual_kl_logger,
28
+ token_mean_kl,
29
+ )
30
+ from composer_replication.integrations.altered_minds.ladder import (
31
+ LADDER_KL_BETA,
32
+ channel_ladder_configs,
33
+ )
34
+ from composer_replication.integrations.altered_minds.reward import (
35
+ MMLUFormatReward,
36
+ parse_final_answer,
37
+ randomize_options,
38
+ )
39
+
40
+ __all__ = [
41
+ "MMLUFormatReward",
42
+ "parse_final_answer",
43
+ "randomize_options",
44
+ "dual_kl_logger",
45
+ "token_mean_kl",
46
+ "channel_ladder_configs",
47
+ "LADDER_KL_BETA",
48
+ ]
composer_replication/integrations/altered_minds/kl_logging.py ADDED
@@ -0,0 +1,110 @@
1
+ """kl_logging.py — dual_kl_logger (ADR-013, framework-side, generic).
2
+
3
+ The washout/amplification instrument. Given per-token logprobs from three
4
+ forward passes on the SAME answer+reasoning tokens:
5
+
6
+ - policy: the model currently being RL-trained
7
+ - altered_init: the altered SFT checkpoint the run STARTED from (the locus
8
+ of the cognitive-distortion signature)
9
+ - unaltered_base: the original base model BEFORE personality SFT
10
+
11
+ returns ``{'kl_to_altered_init': float, 'kl_to_base': float}``.
12
+
13
+ NEITHER KL is optimized by default — both are diagnostics:
14
+ - ``kl_to_altered_init`` rising means the policy is moving AWAY from the
15
+ altered checkpoint (task-RL is *changing* the alteration).
16
+ - ``kl_to_base`` measures distance to the unaltered base. If
17
+ ``kl_to_base`` SHRINKS while ``kl_to_altered_init`` grows, the alteration
18
+ is WASHING OUT (the policy drifts back toward base). If ``kl_to_base``
19
+ GROWS faster than ``kl_to_altered_init``, the alteration is being AMPLIFIED
20
+ (the policy moves further from base than the altered init already was) —
21
+ the ADR-013 amplification hypothesis, most likely on the SDPO channel.
22
+
23
+ Token-mean KL is used (mean over the masked answer+reasoning tokens), the
24
+ standard diagnostic convention. The math is the discrete KL between the two
25
+ softmax distributions implied by the logprob tensors:
26
+
27
+ KL(p || q) = sum_v p_v (log p_v - log q_v)
28
+
29
+ where ``p`` is the policy's per-token distribution. This is unit-testable on
30
+ toy tensors: KL(p || p) == 0, and KL typically grows as the policy moves away from the reference.
31
+ """
32
+ from __future__ import annotations
33
+
34
+ from typing import Any
35
+
36
+ import torch
37
+
38
+ __all__ = ["dual_kl_logger", "token_mean_kl"]
39
+
40
+
41
+ def _as_log_probs(logprobs: torch.Tensor) -> torch.Tensor:
42
+ """Normalize an input that may be raw logits OR already-log-probs to valid
43
+ log-probabilities along the last (vocab) dim.
44
+
45
+ We re-apply ``log_softmax`` defensively: it is idempotent on a genuine
46
+ log-prob tensor up to floating point (log_softmax of log-probs == log-probs
47
+ since they already sum-exp to 1), and converts raw logits correctly. This
48
+ makes the logger robust to either calling convention.
49
+ """
50
+ return torch.log_softmax(logprobs.to(torch.float64), dim=-1)
51
+
52
+
53
+ def token_mean_kl(
54
+ policy_logprobs: torch.Tensor,
55
+ ref_logprobs: torch.Tensor,
56
+ mask: torch.Tensor | None = None,
57
+ ) -> float:
58
+ """Token-mean KL(policy || ref) over distributions on the last dim.
59
+
60
+ Args:
61
+ policy_logprobs: (..., V) logits or log-probs for the policy.
62
+ ref_logprobs: (..., V) logits or log-probs for the reference.
63
+ mask: optional (...,) mask of tokens to include (1/True = include). If
64
+ None, all tokens count.
65
+
66
+ Returns:
67
+ scalar token-mean KL as a python float (>= 0 up to float error).
68
+ """
69
+ log_p = _as_log_probs(policy_logprobs)
70
+ log_q = _as_log_probs(ref_logprobs)
71
+ p = log_p.exp()
72
+ # per-token KL: sum over vocab of p * (log p - log q)
73
+ per_token = (p * (log_p - log_q)).sum(dim=-1) # (...,)
74
+
75
+ if mask is not None:
76
+ m = mask.to(per_token.dtype)
77
+ denom = m.sum()
78
+ if float(denom) == 0.0:
79
+ return 0.0
80
+ return float((per_token * m).sum() / denom)
81
+ return float(per_token.mean())
82
+
83
+
84
+ def dual_kl_logger(
85
+ policy_logprobs: torch.Tensor,
86
+ altered_init_logprobs: torch.Tensor,
87
+ unaltered_base_logprobs: torch.Tensor,
88
+ mask: torch.Tensor | None = None,
89
+ **_: Any,
90
+ ) -> dict[str, float]:
91
+ """Compute the two diagnostic KLs for a step.
92
+
93
+ Args:
94
+ policy_logprobs: (..., V) policy logits/log-probs on the
95
+ answer+reasoning tokens.
96
+ altered_init_logprobs: (..., V) for the altered SFT init.
97
+ unaltered_base_logprobs:(..., V) for the unaltered base.
98
+ mask: optional (...,) token mask (answer+reasoning tokens to score).
99
+
100
+ Returns:
101
+ ``{'kl_to_altered_init': float, 'kl_to_base': float}``.
102
+ """
103
+ return {
104
+ "kl_to_altered_init": token_mean_kl(
105
+ policy_logprobs, altered_init_logprobs, mask
106
+ ),
107
+ "kl_to_base": token_mean_kl(
108
+ policy_logprobs, unaltered_base_logprobs, mask
109
+ ),
110
+ }
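
The washout/amplification reading described in the kl_logging.py docstring can be sketched without torch. The following is a minimal, stdlib-only illustration, not the shipped implementation: `log_softmax`, `token_mean_kl_py`, and `classify_drift` are hypothetical names introduced here, and the classification rule only restates the docstring's heuristic. Comparing the first and last logged values is an assumed, simplistic trend test.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one vocab row (list of floats)."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def token_mean_kl_py(policy_rows, ref_rows, mask=None):
    """Token-mean KL(policy || ref); each row is one token's logits or log-probs."""
    if mask is None:
        mask = [1] * len(policy_rows)
    total, count = 0.0, 0
    for p_row, q_row, keep in zip(policy_rows, ref_rows, mask):
        if not keep:
            continue
        log_p, log_q = log_softmax(p_row), log_softmax(q_row)
        total += sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
        count += 1
    # Same guard as the diff above: an all-masked batch reports 0.0.
    return total / count if count else 0.0


def classify_drift(history):
    """Hypothetical helper: label a run from (kl_to_altered_init, kl_to_base) pairs."""
    (init0, base0), (init1, base1) = history[0], history[-1]
    d_init, d_base = init1 - init0, base1 - base0
    if d_init > 0 and d_base < 0:
        return "washout"        # leaving the altered init, approaching base
    if d_base > 0 and d_base > d_init:
        return "amplification"  # moving away from base faster than from init
    return "inconclusive"


policy = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
ref = [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]

kl_same = token_mean_kl_py(policy, policy)
kl_diff = token_mean_kl_py(policy, ref)
kl_masked = token_mean_kl_py(policy, ref, mask=[0, 1])  # keeps only the identical token
```

One design point worth noting: because the all-masked case returns 0.0 rather than raising, a logged zero can mean either "no drift" or "no tokens scored", so a runner may want to log the mask sum alongside the two KLs.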
composer_replication/integrations/altered_minds/ladder.py ADDED
@@ -0,0 +1,97 @@
1
+ """ladder.py — channel_ladder_configs (ADR-013, the experiment design).
2
+
3
+ The isolated-channel ladder REPLACES the old combined alpha=0.2/beta=0.4 recipe
4
+ (superseded; see docs/ALTERED_MINDS_TIE_IN.md Phase 3). Per ADR-013, a combined
5
+ run confounds four effects (task RL, self-distillation of altered reasoning,
6
+ frontier-teacher imitation, KL anchoring) and is scientifically uninterpretable.
7
+ Worse, SDPO against the altered model's OWN hint-conditioned forward pass can
8
+ AMPLIFY the distortion, so it is an experimental intervention, not a stabilizer.
9
+
10
+ The ladder isolates channels so each effect is attributable:
11
+
12
+ | Arm | alpha_sdpo | beta_replay | Purpose |
13
+ |-----|------------|-------------|----------------------------------|
14
+ | A0 | — | — | altered SFT, no RL (control) |
15
+ | A1 | 0.0 | 0.0 | GRPO-only baseline |
16
+ | A2 | 0.02 | 0.0 | +SDPO small (amplification probe)|
17
+ | A3 | 0.0 | 0.05 | +replay-DPO small (washout probe)|
18
+ | A4 | 0.02 | 0.05 | combined — ONLY after A1-A3 |
19
+
20
+ ``kl_beta`` (KL-to-altered-init coef) = 0.02 for all RL arms. A0 is a sentinel
21
+ (no RL) so its alpha/beta/kl_beta are None.
22
+ """
23
+ from __future__ import annotations
24
+
25
+ from typing import Any
26
+
27
+ __all__ = ["channel_ladder_configs", "LADDER_KL_BETA"]
28
+
29
+ #: KL-to-altered-init coefficient applied to every RL arm (A1-A4).
30
+ LADDER_KL_BETA = 0.02
31
+
32
+
33
+ def channel_ladder_configs() -> list[dict[str, Any]]:
34
+ """Return the ordered A0-A4 arm configs.
35
+
36
+ Each arm is a dict with keys: ``arm``, ``alpha_sdpo``, ``beta_replay``,
37
+ ``kl_beta``, ``note``. A0 is the no-RL sentinel (alpha/beta/kl_beta = None).
38
+
39
+ A runner sweeps these with IDENTICAL seeds/prompts so any observed change in
40
+ the alteration signature is attributable to the single channel that arm
41
+ turns on relative to A1.
42
+ """
43
+ return [
44
+ {
45
+ "arm": "A0",
46
+ "alpha_sdpo": None,
47
+ "beta_replay": None,
48
+ "kl_beta": None,
49
+ "note": (
50
+ "Control: altered SFT checkpoint, NO RL. Sentinel arm used to "
51
+ "anchor the pre-RL alteration signature."
52
+ ),
53
+ },
54
+ {
55
+ "arm": "A1",
56
+ "alpha_sdpo": 0.0,
57
+ "beta_replay": 0.0,
58
+ "kl_beta": LADDER_KL_BETA,
59
+ "note": (
60
+ "GRPO-only baseline (both extra channels OFF). Isolates the "
61
+ "effect of task-driven RL alone on the alteration."
62
+ ),
63
+ },
64
+ {
65
+ "arm": "A2",
66
+ "alpha_sdpo": 0.02,
67
+ "beta_replay": 0.0,
68
+ "kl_beta": LADDER_KL_BETA,
69
+ "note": (
70
+ "+SDPO small (amplification probe). SDPO ONLY vs A1: tests "
71
+ "whether self-distillation against the altered model's own "
72
+ "hint-conditioned forward pass AMPLIFIES the distortion."
73
+ ),
74
+ },
75
+ {
76
+ "arm": "A3",
77
+ "alpha_sdpo": 0.0,
78
+ "beta_replay": 0.05,
79
+ "kl_beta": LADDER_KL_BETA,
80
+ "note": (
81
+ "+replay-DPO small (washout probe). Trace-replay-DPO ONLY vs "
82
+ "A1: tests whether frontier-teacher disagreement WASHES OUT the "
83
+ "alteration toward base."
84
+ ),
85
+ },
86
+ {
87
+ "arm": "A4",
88
+ "alpha_sdpo": 0.02,
89
+ "beta_replay": 0.05,
90
+ "kl_beta": LADDER_KL_BETA,
91
+ "note": (
92
+ "Combined — run ONLY after A1-A3 are interpretable. Confounds "
93
+ "channels by design; meaningful only as a capstone once the "
94
+ "isolated arms are understood."
95
+ ),
96
+ },
97
+ ]
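
The attribution logic behind the ladder (each arm differs from A1 by the channels it switches on, and every arm shares seeds) can be sketched as a small planning step. This is an illustrative, self-contained sketch: the arm values are copied from the table above, while `channels_added_vs_baseline`, `plan_sweep`, the `mode` labels, and the seed value are hypothetical and not part of the shipped package.

```python
from typing import Any

# Arm values copied from the ladder table; None marks the no-RL control (A0).
LADDER: list[dict[str, Any]] = [
    {"arm": "A0", "alpha_sdpo": None, "beta_replay": None, "kl_beta": None},
    {"arm": "A1", "alpha_sdpo": 0.0, "beta_replay": 0.0, "kl_beta": 0.02},
    {"arm": "A2", "alpha_sdpo": 0.02, "beta_replay": 0.0, "kl_beta": 0.02},
    {"arm": "A3", "alpha_sdpo": 0.0, "beta_replay": 0.05, "kl_beta": 0.02},
    {"arm": "A4", "alpha_sdpo": 0.02, "beta_replay": 0.05, "kl_beta": 0.02},
]

CHANNEL_KEYS = ("alpha_sdpo", "beta_replay")


def channels_added_vs_baseline(arm, baseline):
    """Channel coefficients that `arm` turns on relative to `baseline` (A1)."""
    return [k for k in CHANNEL_KEYS if arm[k] and not baseline[k]]


def plan_sweep(ladder, seed):
    """Hypothetical runner plan: one shared seed; A0 is evaluated without RL."""
    baseline = next(a for a in ladder if a["arm"] == "A1")
    plan = []
    for arm in ladder:
        is_control = arm["kl_beta"] is None
        plan.append(
            {
                "arm": arm["arm"],
                "mode": "eval_only" if is_control else "rl",
                "seed": seed,
                "isolates": [] if is_control else channels_added_vs_baseline(arm, baseline),
            }
        )
    return plan


plan = plan_sweep(LADDER, seed=1234)
```

An arm whose `isolates` list has more than one entry (A4 here) is the confounded case the docstring warns about, which a runner could use as a gate before launching it.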
composer_replication/integrations/altered_minds/reward.py ADDED
@@ -0,0 +1,231 @@
1
+ """reward.py — MMLUFormatReward (ADR-013, framework-side, generic).
2
+
3
+ A structured-answer reward for RL on MMLU-style multiple-choice tasks. It scores
4
+ ONLY the final answer letter + format validity — never the rationale's style or
5
+ content. This is deliberate: the north-star use case (ADR-013) drives RL on a
6
+ *personality-altered* model, and rewarding "persuasive" rationale would reward
7
+ the very cognitive-distortion signature we are trying to measure, which would
8
+ bias the measurement toward it.
9
+
10
+ Scoring per completion:
11
+ +1.0 final answer parses and equals the gold letter
12
+ 0.0 final answer parses but is wrong
13
+ -0.2 no parseable final-answer marker (unparseable)
14
+ -0.1 multiple DISTINCT final-answer markers present (format hacking)
15
+ -len_penalty small penalty past a rationale character cap
16
+
17
+ Parsing accepts (case-insensitive, last match wins for the canonical letter):
18
+ - ``Answer: X`` (X in A-D)
19
+ - JSON ``{"answer": "X"}``
20
+
21
+ Exploit detection: ``MMLUFormatReward`` keeps a running count of chosen letters
22
+ (``option_distribution``) so an "always C" / option-prior exploit is detectable
23
+ by inspecting that distribution after a run. A companion ``randomize_options``
24
+ helper shuffles option order with an original->shuffled label remap so the
25
+ training data itself can be de-biased.
26
+ """
27
+ from __future__ import annotations
28
+
29
+ import json
30
+ import re
31
+ from collections import Counter
32
+ from dataclasses import dataclass, field
33
+ from typing import Any
34
+
35
+ __all__ = ["MMLUFormatReward", "randomize_options", "parse_final_answer"]
36
+
37
+ _VALID_LETTERS = ("A", "B", "C", "D")
38
+
39
+ # ``Answer: X`` — tolerant of whitespace, optional markdown bold/asterisks.
40
+ _ANSWER_RE = re.compile(r"answer\s*[:\-]\s*\*{0,2}([A-D])\b", re.IGNORECASE)
41
+ # JSON ``{"answer": "X"}`` — extract the value of an "answer" key.
42
+ _JSON_ANSWER_RE = re.compile(
43
+ r'["\']answer["\']\s*:\s*["\']([A-D])["\']', re.IGNORECASE
44
+ )
45
+
46
+
47
+ def _find_markers(text: str) -> list[str]:
48
+ """Return ALL final-answer letters found (uppercased), in order of appearance.
49
+
50
+ Used both to pick the canonical answer (last match wins) and to detect the
51
+ multiple-distinct-markers format-hacking case.
52
+ """
53
+ markers: list[tuple[int, str]] = []
54
+ for m in _ANSWER_RE.finditer(text or ""):
55
+ markers.append((m.start(), m.group(1).upper()))
56
+ for m in _JSON_ANSWER_RE.finditer(text or ""):
57
+ markers.append((m.start(), m.group(1).upper()))
58
+ markers.sort(key=lambda p: p[0])
59
+ return [letter for _, letter in markers]
60
+
61
+
62
+ def parse_final_answer(completion: str) -> tuple[str | None, int]:
63
+ """Parse the final answer letter from a completion.
64
+
65
+ Returns ``(letter_or_None, n_distinct_markers)``. ``letter`` is the LAST
66
+ marker found (last match wins). ``n_distinct_markers`` counts DISTINCT
67
+ letters across all markers (so two ``Answer: C`` are not penalized, but
68
+ ``Answer: A ... Answer: B`` is).
69
+ """
70
+ markers = _find_markers(completion)
71
+ if not markers:
72
+ return None, 0
73
+ distinct = len(set(markers))
74
+ return markers[-1], distinct
75
+
76
+
77
+ @dataclass
78
+ class MMLUFormatReward:
79
+ """Callable reward_fn(prompts, completions, *, answers, **kwargs) -> list[float].
80
+
81
+ Args:
82
+ rationale_char_cap: completions longer than this incur a small length
83
+ penalty (``length_penalty_per_char`` per char past the cap). Caps
84
+ verbosity without scoring rationale content.
85
+ length_penalty_per_char: per-character penalty past the cap.
86
+ correct_reward / wrong_reward / unparseable_reward /
87
+ multiple_answers_reward: the scalar rewards for each outcome.
88
+
89
+ Side effect: ``option_distribution`` (a Counter over chosen letters) and
90
+ ``n_scored`` accumulate across calls so an "always C" exploit is detectable
91
+ via ``exploit_report()``.
92
+ """
93
+
94
+ rationale_char_cap: int = 512
95
+ length_penalty_per_char: float = 0.001
96
+ correct_reward: float = 1.0
97
+ wrong_reward: float = 0.0
98
+ unparseable_reward: float = -0.2
99
+ multiple_answers_reward: float = -0.1
100
+ option_distribution: Counter = field(default_factory=Counter)
101
+ n_scored: int = 0
102
+
103
+ def __call__(
104
+ self,
105
+ prompts: Any = None,
106
+ completions: list[str] | None = None,
107
+ *,
108
+ answers: list[str] | None = None,
109
+ **kwargs: Any,
110
+ ) -> list[float]:
111
+ """Score a batch of completions against gold ``answers`` (letters A-D).
112
+
113
+ ``prompts`` is accepted for the TRL reward-fn signature but unused
114
+ (we score the completion text only). ``answers`` is required.
115
+ """
116
+ if completions is None:
117
+ completions = []
118
+ if answers is None:
119
+ raise ValueError(
120
+ "MMLUFormatReward requires `answers` (the gold letters, one per "
121
+ "completion). Pass via reward_fn(..., answers=[...])."
122
+ )
123
+ if len(answers) != len(completions):
124
+ raise ValueError(
125
+ f"answers/completions length mismatch: {len(answers)} vs "
126
+ f"{len(completions)}."
127
+ )
128
+
129
+ rewards: list[float] = []
130
+ for completion, gold in zip(completions, answers):
131
+ rewards.append(self._score_one(completion, gold))
132
+ return rewards
133
+
134
+ def _score_one(self, completion: str, gold: str) -> float:
135
+ letter, n_distinct = parse_final_answer(completion)
136
+ self.n_scored += 1
137
+
138
+ if letter is None:
139
+ # Unparseable: no usable final-answer marker. Length penalty does
140
+ # not apply (we never even parsed a letter to reward/penalize).
141
+ return self.unparseable_reward
142
+
143
+ # Log the chosen letter for exploit detection (always-C etc.).
144
+ self.option_distribution[letter] += 1
145
+
146
+ if n_distinct > 1:
147
+ # Multiple DISTINCT markers — format hacking. Penalize regardless
148
+ # of correctness (the model is hedging / gaming the parser).
149
+ base = self.multiple_answers_reward
150
+ elif gold is not None and letter == str(gold).strip().upper():
151
+ base = self.correct_reward
152
+ else:
153
+ base = self.wrong_reward
154
+
155
+ return base - self._length_penalty(completion)
156
+
157
+ def _length_penalty(self, completion: str) -> float:
158
+ over = max(0, len(completion or "") - self.rationale_char_cap)
159
+ return self.length_penalty_per_char * over
160
+
161
+ # ------------------------------------------------------------------
162
+ # Exploit detection
163
+ # ------------------------------------------------------------------
164
+ def exploit_report(self) -> dict[str, Any]:
165
+ """Summarize the chosen-letter distribution so an option-prior exploit
166
+ (e.g. "always C") is detectable.
167
+
168
+ Returns a dict with the raw counts, the most common letter, and its
169
+ fraction of all parsed answers. A healthy run is ~uniform over A-D; a
170
+ fraction near 1.0 for a single letter is the exploit signature.
171
+ """
172
+ total = sum(self.option_distribution.values())
173
+ if total == 0:
174
+ return {
175
+ "counts": {},
176
+ "total_parsed": 0,
177
+ "most_common": None,
178
+ "max_fraction": 0.0,
179
+ }
180
+ letter, count = self.option_distribution.most_common(1)[0]
181
+ return {
182
+ "counts": dict(self.option_distribution),
183
+ "total_parsed": total,
184
+ "most_common": letter,
185
+ "max_fraction": count / total,
186
+ }
187
+
188
+
189
+ def randomize_options(
190
+ item: dict[str, Any], seed: int
191
+ ) -> tuple[dict[str, Any], dict[str, str]]:
192
+ """Shuffle the multiple-choice option order, tracking original->shuffled letters.
193
+
194
+ Args:
195
+ item: a dict with ``options`` (list[str], A-first ordering) and
196
+ ``answer`` (the gold letter, A-D). Other keys are passed through.
197
+ seed: deterministic RNG seed for the shuffle.
198
+
199
+ Returns:
200
+ ``(shuffled_item, label_remap)`` where ``shuffled_item`` has the options
201
+ reordered and its ``answer`` updated to the gold option's NEW letter, and
202
+ ``label_remap`` maps each ORIGINAL letter -> its NEW (shuffled) letter.
203
+
204
+ This de-biases an option-prior exploit at the data level: if the gold answer
205
+ is no longer correlated with a fixed position, "always C" stops working.
206
+ """
207
+ import random
208
+
209
+ options = list(item.get("options", []))
210
+ n = len(options)
211
+ if n == 0:
212
+ return dict(item), {}
213
+ orig_letters = [chr(ord("A") + i) for i in range(n)]
214
+
215
+ rng = random.Random(seed)
216
+ perm = list(range(n))
217
+ rng.shuffle(perm)
218
+ # perm[new_pos] = old_pos => option at new_pos is the old option perm[new_pos]
219
+ shuffled_options = [options[perm[new]] for new in range(n)]
220
+
221
+ # original letter -> new letter: old index `perm[new]` moved to position `new`.
222
+ label_remap: dict[str, str] = {}
223
+ for new_pos, old_pos in enumerate(perm):
224
+ label_remap[orig_letters[old_pos]] = orig_letters[new_pos]
225
+
226
+ shuffled_item = dict(item)
227
+ shuffled_item["options"] = shuffled_options
228
+ gold = str(item.get("answer", "")).strip().upper()
229
+ if gold in label_remap:
230
+ shuffled_item["answer"] = label_remap[gold]
231
+ return shuffled_item, label_remap
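
The parsing and scoring rules in reward.py can be exercised in isolation. The sketch below copies the two regular expressions verbatim from the diff and re-implements the parse and score steps in compact form; `parse`, `score`, and the sample strings are illustrative stand-ins, not the shipped `parse_final_answer` or `MMLUFormatReward`, and the sketch omits the option-distribution logging.

```python
import re

# Patterns copied verbatim from reward.py above.
_ANSWER_RE = re.compile(r"answer\s*[:\-]\s*\*{0,2}([A-D])\b", re.IGNORECASE)
_JSON_ANSWER_RE = re.compile(
    r'["\']answer["\']\s*:\s*["\']([A-D])["\']', re.IGNORECASE
)


def parse(text):
    """Return (last_letter_or_None, n_distinct_letters), last match wins."""
    hits = [
        (m.start(), m.group(1).upper())
        for rx in (_ANSWER_RE, _JSON_ANSWER_RE)
        for m in rx.finditer(text)
    ]
    hits.sort(key=lambda h: h[0])
    letters = [letter for _, letter in hits]
    return (letters[-1], len(set(letters))) if letters else (None, 0)


def score(text, gold, cap=512, per_char=0.001):
    """Scoring table from the module docstring, using the default reward values."""
    letter, distinct = parse(text)
    if letter is None:
        return -0.2  # unparseable: no length penalty applied
    if distinct > 1:
        base = -0.1  # multiple distinct markers
    else:
        base = 1.0 if letter == gold.strip().upper() else 0.0
    return base - per_char * max(0, len(text) - cap)


cases = {
    "plain": parse("Reasoning. Answer: B"),
    "bold_letter": parse("Answer: **C**"),
    "json": parse('{"answer": "d"}'),
    "hedged": parse("Answer: A, or maybe Answer: D"),
    "bold_label": parse("**Answer:** B"),
    "prose": parse("The answer - a tricky one - is unclear"),
}
```

The last two cases are edge behaviors of the patterns as written: a bolded label (`**Answer:** B`) is not recognized, and prose such as "answer - a ..." is read as the letter A because matching is case-insensitive. Whether either matters depends on the completions a given model produces, so they may be worth covering with tests if those formats occur in practice.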
composer_replication/integrations/altered_minds/tests/__init__.py ADDED
File without changes
composer_replication/integrations/altered_minds/tests/test_channel_ladder.py ADDED
@@ -0,0 +1,362 @@
1
+ """Tests for the ADR-013 altered_minds integration glue + the B4 SDPO-fires proof.
2
+
3
+ Covers the ADR-013 acceptance gate:
4
+ - MMLUFormatReward: correct→+1, wrong→0, unparseable→−0.2, multiple→−0.1,
5
+ length-penalty, and an "always C" option-prior exploit is DETECTABLE via the
6
+ logged option distribution. Rationale style is NOT scored.
7
+ - dual_kl_logger: KL(p‖p)==0 and KL grows as the policy moves.
8
+ - channel_ladder_configs: A1 both off, A2 SDPO-only, A3 replay-only.
9
+ - B4: the SDPO channel actually FIRES (NONZERO loss) with REAL collator-built
10
+ alignment indices. See the docstrings on the test_b4_* tests for the honest
11
+ stub-vs-real note.
12
+
13
+ All CPU-only and fast (stub tokenizer + tiny model — no model download).
14
+ """
15
+ from __future__ import annotations
16
+
17
+ import pytest
18
+ import torch
19
+
20
+ from composer_replication.integrations.altered_minds import (
21
+ MMLUFormatReward,
22
+ channel_ladder_configs,
23
+ dual_kl_logger,
24
+ randomize_options,
25
+ )
26
+
27
+
28
+ # ===========================================================================
29
+ # MMLUFormatReward
30
+ # ===========================================================================
31
+
32
+ def test_reward_correct_wrong_unparseable_multiple():
33
+ r = MMLUFormatReward()
34
+ completions = [
35
+ "Reasoning blah. Answer: B", # correct
36
+ "I think it's Answer: A", # wrong (gold C)
37
+ "no marker here at all", # unparseable
38
+ "Answer: A then actually Answer: D", # multiple distinct
39
+ '{"answer": "C"}', # JSON correct
40
+ ]
41
+ answers = ["B", "C", "B", "A", "C"]
42
+ out = r(prompts=None, completions=completions, answers=answers)
43
+ assert out[0] == pytest.approx(1.0) # correct
44
+ assert out[1] == pytest.approx(0.0) # wrong
45
+ assert out[2] == pytest.approx(-0.2) # unparseable
46
+ assert out[3] == pytest.approx(-0.1) # multiple distinct markers
47
+ assert out[4] == pytest.approx(1.0) # JSON correct
48
+
49
+
50
+ def test_reward_last_match_wins_same_letter_not_penalized():
51
+ """Two markers of the SAME letter is not 'multiple distinct' — last wins."""
52
+ r = MMLUFormatReward()
53
+ out = r(completions=["Answer: C ... so my final Answer: C"], answers=["C"])
54
+ assert out[0] == pytest.approx(1.0)
55
+
56
+
57
+ def test_reward_case_insensitive_and_json_variants():
58
+ r = MMLUFormatReward()
59
+ out = r(
60
+ completions=["answer: d", '{"answer":"a"}'],
61
+ answers=["D", "A"],
62
+ )
63
+ assert out[0] == pytest.approx(1.0)
64
+ assert out[1] == pytest.approx(1.0)
65
+
66
+
67
+ def test_reward_length_penalty_only_past_cap():
68
+ """A correct-but-long completion is penalized by ~0.001/char past the cap;
69
+ a short one is not. Rationale CONTENT is never scored — only length."""
70
+ r = MMLUFormatReward(rationale_char_cap=20, length_penalty_per_char=0.001)
71
+ short = "Answer: B" # under cap
72
+ long = "x" * 120 + " Answer: B" # ~130 chars, 110 over cap
73
+ out = r(completions=[short, long], answers=["B", "B"])
74
+ assert out[0] == pytest.approx(1.0)
75
+ # 130 - 20 = 110 over => penalty 0.110; reward 1.0 - 0.110
76
+ assert out[1] < 1.0
77
+ assert out[1] == pytest.approx(1.0 - 0.001 * (len(long) - 20))
78
+
79
+
80
+ def test_reward_always_C_exploit_is_detectable():
81
+ """An 'always C' policy that happens to be right when gold==C scores well on
82
+ those items, but the logged option distribution reveals the exploit."""
83
+ r = MMLUFormatReward()
84
+ completions = ["Answer: C" for _ in range(10)]
85
+ golds = ["C", "A", "B", "C", "D", "A", "C", "B", "C", "D"]
86
+ r(completions=completions, answers=golds)
87
+ report = r.exploit_report()
88
+ assert report["most_common"] == "C"
89
+ # Every parsed answer was C => fraction 1.0 — the exploit signature.
90
+ assert report["max_fraction"] == pytest.approx(1.0)
91
+ assert report["counts"] == {"C": 10}
92
+
93
+
94
+ def test_reward_requires_answers():
95
+ r = MMLUFormatReward()
96
+ with pytest.raises(ValueError, match="requires `answers`"):
97
+ r(completions=["Answer: A"])
98
+
99
+
100
+ def test_randomize_options_tracks_label_remap_and_updates_gold():
101
+ item = {"question": "q", "options": ["w", "x", "y", "z"], "answer": "A"}
102
+ shuffled, remap = randomize_options(item, seed=7)
103
+ # All four letters map to four distinct new letters (a permutation).
104
+ assert sorted(remap.keys()) == ["A", "B", "C", "D"]
105
+ assert sorted(remap.values()) == ["A", "B", "C", "D"]
106
+ # The gold option's text ("w", originally A) now lives at its remapped letter.
107
+ new_gold_letter = shuffled["answer"]
108
+ new_gold_idx = ord(new_gold_letter) - ord("A")
109
+ assert shuffled["options"][new_gold_idx] == "w"
110
+ assert remap["A"] == new_gold_letter
111
+ # Determinism.
112
+ shuffled2, remap2 = randomize_options(item, seed=7)
113
+ assert remap == remap2 and shuffled["options"] == shuffled2["options"]
114
+
115
+
116
+ # ===========================================================================
117
+ # dual_kl_logger
118
+ # ===========================================================================
119
+
120
+ def test_dual_kl_self_is_zero():
121
+ """KL(p‖p) == 0 for both diagnostics."""
122
+ logits = torch.randn(2, 5, 16)
123
+ out = dual_kl_logger(logits, logits, logits)
124
+ assert out["kl_to_altered_init"] == pytest.approx(0.0, abs=1e-6)
125
+ assert out["kl_to_base"] == pytest.approx(0.0, abs=1e-6)
126
+
127
+
128
+ def test_dual_kl_grows_as_policy_moves():
129
+ """As the policy distribution moves further from a fixed reference, the KL
130
+ grows (the near policy scores lower than the far one). Both diagnostics are non-negative."""
131
+ torch.manual_seed(0)
132
+ ref = torch.randn(1, 4, 16)
133
+ base = torch.randn(1, 4, 16)
134
+
135
+ near = ref + 0.1 * torch.randn_like(ref)
136
+ far = ref + 2.0 * torch.randn_like(ref)
137
+
138
+ kl_near = dual_kl_logger(near, ref, base)["kl_to_altered_init"]
139
+ kl_far = dual_kl_logger(far, ref, base)["kl_to_altered_init"]
140
+
141
+ assert kl_near >= -1e-9
142
+ assert kl_far > kl_near, f"KL should grow as policy moves: {kl_near} -> {kl_far}"
143
+
144
+
145
+ def test_dual_kl_mask_restricts_tokens():
146
+ """A token mask restricts the mean to the masked answer+reasoning tokens."""
147
+ torch.manual_seed(1)
148
+ policy = torch.randn(1, 4, 8)
149
+ ref = torch.randn(1, 4, 8)
150
+ base = torch.randn(1, 4, 8)
151
+ mask = torch.tensor([[1, 1, 0, 0]])
152
+ out = dual_kl_logger(policy, ref, base, mask=mask)
153
+ # Masked-all-zero => 0.0 (guarded), nonzero mask => finite non-negative.
154
+ assert out["kl_to_altered_init"] >= -1e-9
155
+ zero = dual_kl_logger(policy, ref, base, mask=torch.zeros(1, 4))
156
+ assert zero["kl_to_altered_init"] == 0.0
157
+ assert zero["kl_to_base"] == 0.0
158
+
159
+
+ # ===========================================================================
+ # channel_ladder_configs
+ # ===========================================================================
+
+ def test_ladder_arms_and_order():
+     arms = channel_ladder_configs()
+     assert [a["arm"] for a in arms] == ["A0", "A1", "A2", "A3", "A4"]
+
+
+ def test_ladder_a0_is_no_rl_sentinel():
+     a0 = channel_ladder_configs()[0]
+     assert a0["arm"] == "A0"
+     assert a0["alpha_sdpo"] is None
+     assert a0["beta_replay"] is None
+     assert a0["kl_beta"] is None
+
+
+ def test_ladder_a1_both_off():
+     a1 = channel_ladder_configs()[1]
+     assert a1["alpha_sdpo"] == 0.0
+     assert a1["beta_replay"] == 0.0
+     assert a1["kl_beta"] == 0.02
+
+
+ def test_ladder_a2_sdpo_only():
+     a2 = channel_ladder_configs()[2]
+     assert a2["alpha_sdpo"] == 0.02
+     assert a2["beta_replay"] == 0.0
+     assert a2["kl_beta"] == 0.02
+
+
+ def test_ladder_a3_replay_only():
+     a3 = channel_ladder_configs()[3]
+     assert a3["alpha_sdpo"] == 0.0
+     assert a3["beta_replay"] == 0.05
+     assert a3["kl_beta"] == 0.02
+
+
+ def test_ladder_a4_combined():
+     a4 = channel_ladder_configs()[4]
+     assert a4["alpha_sdpo"] == 0.02
+     assert a4["beta_replay"] == 0.05
+
+
+ # ===========================================================================
+ # B4 — the SDPO channel actually FIRES (NONZERO) with REAL collator indices
+ # ===========================================================================
+ #
+ # HONEST NOTE ON STUB-VS-REAL (ADR-013 B4 acceptance):
+ #
+ # This proof uses the same TinyLM stub pattern as
+ # trainer/tests/test_sdpo_alignment_indices.py, NOT a real Qwen checkpoint
+ # (kept offline/CPU and deterministic). The alignment indices are REAL: they are
+ # built by the production ComposerDataCollator from a trace that HAS an error
+ # turn (so ctx_teacher_input_ids + student/teacher_response_idx are genuinely
+ # emitted by the shipped collator, exactly as in a real run).
+ #
+ # Why we must perturb the student tokens to get a NONZERO loss: the collator's
+ # placeholder-alignment trick makes student and teacher carry the SAME token ids
+ # at the SAME absolute positions at valid aligned indices, so a deterministic
+ # stub yields JSD≈0 there (the CORRECT answer for a perfectly-aligned identical
+ # model — see that test's gate-3 note). To prove the channel genuinely GATHERS
+ # the aligned positions and computes nonzero divergence, we make the student's
+ # input_ids DIFFER from the teacher's at exactly the aligned response positions
+ # — this mimics the hint actually changing the recovery tokens (the real-world
+ # case where SDPO has a signal to distill). With a position-dependent stub,
+ # different aligned token ids => different logits => provably NONZERO JSD on a
+ # grad path, through the real collator-built indices.
+
+ from composer_replication.trainer.data_collator import (  # noqa: E402
+     CollatorConfig,
+     ComposerDataCollator,
+ )
+
+
+ class _StubTok:
+     """Word-level deterministic tokenizer; apply_chat_template space-joins."""
+
+     pad_token_id = 0
+
+     def __init__(self) -> None:
+         self._v: dict[str, int] = {"<pad>": 0, "<bos>": 1, "<eos>": 2}
+
+     def _id(self, w: str) -> int:
+         if w not in self._v:
+             self._v[w] = len(self._v)
+         return self._v[w]
+
+     def __call__(self, text, **_k):
+         return {"input_ids": [self._id(w) for w in text.split()] if text else []}
+
+     def apply_chat_template(self, messages, tokenize=True, **_k):  # noqa: ARG002
+         return [self._id(w) for w in " ".join(m.get("content", "") for m in messages).split()]
+
+
+ class _TinyLM(torch.nn.Module):
+     """Position-dependent minimal model: model(input_ids=...).logits."""
+
+     def __init__(self, vocab: int = 64, hidden: int = 8, max_pos: int = 512):
+         super().__init__()
+         torch.manual_seed(0)
+         self.embed = torch.nn.Embedding(vocab, hidden)
+         self.pos = torch.nn.Embedding(max_pos, hidden)
+         self.head = torch.nn.Linear(hidden, vocab)
+
+     def forward(self, input_ids: torch.Tensor):
+         T = input_ids.size(1)
+         positions = torch.arange(T, device=input_ids.device).unsqueeze(0)
+         h = self.embed(input_ids) + self.pos(positions)
+
+         class _Out:
+             pass
+
+         out = _Out()
+         out.logits = self.head(h)
+         return out
+
+
+ def _hint_gen(_kind, _meta):
+     return "HINT search before reading"
+
+
+ def _error_trace(trace_id: str, recovery: str = "let me use a real tool instead now"):
+     return {
+         "trace_id": trace_id,
+         "turns": [
+             {"role": "user", "content": "do the task now"},
+             {"role": "user", "content": "tool not found error occurred"},
+             {
+                 "role": "assistant",
+                 "content": recovery,
+                 "tool_error": "tool_not_found",
+                 "error_meta": {},
+             },
+         ],
+         "final_reward": 0.0,
+     }
+
+
+ def _make_sdpo_trainer(alpha_sdpo: float):
+     from composer_replication.trainer.composer_trainer import ComposerReplicationTrainer
+
+     obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
+     obj.alpha_sdpo = alpha_sdpo
+     obj.sdpo_jsd_beta = 0.5
+     obj.sdpo_temperature = 1.0
+     obj.sdpo_token_clip = None
+     obj.strict_sdpo_alignment = True  # production default
+     return obj
+
+
+ def test_b4_sdpo_fires_nonzero_with_real_collator_indices():
+     """B4: with REAL collator-built alignment indices and the student tokens
+     differing from the teacher at the aligned response positions (hint changed
+     the recovery tokens), the SDPO channel gathers those positions and produces
+     a NONZERO JSD on a grad path — proving the channel actually FIRES."""
+     tok = _StubTok()
+     cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False)
+     collator = ComposerDataCollator(tokenizer=tok, config=cfg)
+     batch = collator([_error_trace("b4-fires")])
+
+     # Sanity: the collator genuinely emitted error-site teacher context + indices.
+     assert batch["ctx_teacher_input_ids"].numel() > 0
+     s_idx = batch["student_response_idx"]
+     t_idx = batch["teacher_response_idx"]
+     s_valid = batch["student_response_valid"]
+     assert int(s_valid.sum()) > 0, "no valid aligned positions — collator emitted nothing"
+
+     # Perturb the STUDENT tokens at the aligned response positions so they differ
+     # from the teacher's tokens there (the hint changed the recovery tokens). We
+     # keep the REAL collator-built indices; only the student input_ids change.
+     student_ids = batch["input_ids"].clone()
+     vocab_ceiling = int(
+         max(batch["input_ids"].max(), batch["ctx_teacher_input_ids"].max())
+     ) + 8
+     for b in range(s_idx.shape[0]):
+         for k in range(s_idx.shape[1]):
+             if bool(s_valid[b, k]):
+                 pos = int(s_idx[b, k])
+                 # bump to a different, in-vocab token id (deterministic).
+                 student_ids[b, pos] = (int(student_ids[b, pos]) + 3) % vocab_ceiling
+     batch["input_ids"] = student_ids
+
+     model = _TinyLM(vocab=max(vocab_ceiling, 8))
+     obj = _make_sdpo_trainer(alpha_sdpo=0.02)  # A2 config (SDPO-only small)
+
+     loss = obj._compute_sdpo_loss(model, batch)
+     val = float(loss.detach())
+
+     assert val == val and val not in (float("inf"), float("-inf")), "loss not finite"
+     assert loss.requires_grad, "SDPO loss must be on a grad path"
+     assert val > 1e-6, (
+         f"SDPO channel did not fire: JSD={val} (expected NONZERO once the "
+         "aligned student/teacher tokens differ). The channel must gather the "
+         "real collator indices and compute a positive divergence."
+     )
+
+     # Prove it is differentiable end-to-end: backward populates a real gradient.
+     (obj.alpha_sdpo * loss).backward()
+     grad_norm = sum(
+         float(p.grad.norm()) for p in model.parameters() if p.grad is not None
+     )
+     assert grad_norm > 0.0, "no gradient flowed from the SDPO loss into the model"
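
The reasoning behind the B4 test (identical aligned tokens give JSD≈0, differing tokens give a nonzero JSD) can be illustrated without the trainer. The following is a minimal pure-Python sketch, not the shipped `_compute_sdpo_loss`: the helper names `softmax`, `kl`, and `jsd` are hypothetical, and the mixture weighting is an assumption that only matches the test's `sdpo_jsd_beta = 0.5` setting, where the divergence is symmetric.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one token's logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) in nats; both inputs are strictly positive softmax outputs."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def jsd(student_logits, teacher_logits, beta=0.5):
    """Jensen-Shannon divergence with mixture weight beta (assumed convention)."""
    p = softmax(student_logits)
    q = softmax(teacher_logits)
    mix = [beta * pi + (1.0 - beta) * qi for pi, qi in zip(p, q)]
    return beta * kl(p, mix) + (1.0 - beta) * kl(q, mix)


# Same logits at an aligned position: nothing to distill.
jsd_same = jsd([0.2, -1.0, 0.7], [0.2, -1.0, 0.7])
# Different logits (as after perturbing the student token): positive divergence.
jsd_diff = jsd([0.2, -1.0, 0.7], [1.5, 0.3, -0.4])
```

With `beta=0.5` the value is bounded above by `ln 2`, so a threshold such as the test's `1e-6` separates "fired" from "did not fire" without depending on the model's scale.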
docs/ALTERED_MINDS_TIE_IN.md CHANGED
@@ -79,16 +79,43 @@ Fits inside the user's existing $400 altered-minds budget.

  ### Phase 3 — GRPO with the framework

- Run `composer_replication.recipes.trl.ComposerReplicationTrainer` with:
- - **Channel 1 (GRPO)**: turned ON, reward = MMLU letter-correctness
- - **Channel 2 (SDPO/OPSD)**: turned ON at α=0.2, hint-conditioned
-   against the altered model's own forward pass
- - **Channel 3 (trace-replay DPO)**: turned ON at β=0.4, against the
-   Phase-2 datasets
-
- Train for ~500 steps on a single GPU (Qwen-0.5B feasibility-test
- already confirmed in the framework; for Llama-8B, use Modal + the
- framework's `ServerlessExecutor` per ADR-005 local 5090 is too small).

  ### Phase 4 — re-evaluate


  ### Phase 3 — GRPO with the framework

+ > **⚠️ SUPERSEDED by [ADR-013](adrs/ADR-013-lma-integration-channel-ladder.md).**
+ > The original all-channels-on combined recipe (α=0.2, β=0.4) is **not used**.
+ > A cross-family research critique (2026-05-29) found a combined-first run
+ > **scientifically uninterpretable**: it confounds four effects (task RL,
+ > self-distillation of altered reasoning, frontier-teacher imitation, KL
+ > anchoring), so any observed change in the alteration signature cannot be
+ > attributed to a channel. Worse, **SDPO against the altered model's own
+ > hint-conditioned forward pass is the channel most likely to AMPLIFY the
+ > distortion** (teacher == student-family; if hints add no independent
+ > information, the optimum is to imitate the altered conditional distribution,
+ > sharpening a soft bias into a hard preference). SDPO here is therefore an
+ > *experimental intervention*, not a benign stabilizer.
+
+ **Use the isolated-channel ladder (ADR-013) instead** — sweep arms A0–A4 with
+ identical seeds/prompts so each channel's effect is attributable:
+
+ | Arm | alpha_sdpo | beta_replay | Purpose |
+ |---|---|---|---|
+ | A0 | — | — | altered SFT, no RL (control) |
+ | A1 | 0.0 | 0.0 | GRPO-only baseline |
+ | A2 | **0.02** | 0.0 | +SDPO small (amplification probe) |
+ | A3 | 0.0 | **0.05** | +replay-DPO small (washout probe) |
+ | A4 | 0.02 | 0.05 | combined — only after A1–A3 interpretable |
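
As a rough illustration of the table, the ladder can be expressed as an ordered list of per-arm dictionaries. This is a sketch under stated assumptions, not the shipped `channel_ladder_configs()`: the function name `sketch_channel_ladder` and the `purpose` key are hypothetical, the `None` values for A0 mirror the "no RL" control row, and `kl_beta=0.02` on the RL arms follows the Phase 3 text about KL-to-altered-init.

```python
def sketch_channel_ladder(alpha=0.02, beta=0.05, kl_beta=0.02):
    """Return the A0-A4 arms in sweep order (illustrative only)."""

    def arm(name, alpha_sdpo, beta_replay, kl, purpose):
        return {
            "arm": name,
            "alpha_sdpo": alpha_sdpo,
            "beta_replay": beta_replay,
            "kl_beta": kl,
            "purpose": purpose,  # hypothetical key, for readability only
        }

    return [
        # A0 runs no RL at all, so every RL hyperparameter is a None sentinel.
        arm("A0", None, None, None, "altered SFT, no RL (control)"),
        arm("A1", 0.0, 0.0, kl_beta, "GRPO-only baseline"),
        arm("A2", alpha, 0.0, kl_beta, "+SDPO small (amplification probe)"),
        arm("A3", 0.0, beta, kl_beta, "+replay-DPO small (washout probe)"),
        arm("A4", alpha, beta, kl_beta, "combined"),
    ]


arms = sketch_channel_ladder()
isolated = [a["arm"] for a in arms[1:4]]  # the arms that vary one channel at a time
```

Keeping A1 to A3 identical except for a single coefficient is what makes a change in the alteration signature attributable to one channel.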
+
+ `kl_beta=0.02` (KL-to-altered-init) on every RL arm, adaptive to 0.01–0.03
+ nats/token; hard-stop/LR-cut if KL > ~0.08. The framework provides the ladder
+ via `composer_replication.integrations.altered_minds.channel_ladder_configs()`,
+ the structured `MMLUFormatReward` (scores the final answer letter + format
+ only — never rationale style, so distorted-but-persuasive reasoning is not
+ rewarded), and `dual_kl_logger` (logs KL-to-altered-init **and** KL-to-base each
+ step — the washout-vs-amplification instrument).
114
+ Train for ~500 steps per arm on a single GPU (Qwen-0.5B feasibility-test
115
+ already confirmed; for Llama-8B, use Modal + the framework's `ServerlessExecutor`
116
+ per ADR-005 — local 5090 is too small). The real 8B/LMA-checkpoint run remains
117
+ **user-gated** (it spends grant budget) — ADR-013 ships the capability, proven
118
+ CPU-only on a small model (`examples/altered_minds_channel_ladder/`).
119
 
120
  ### Phase 4 — re-evaluate
121
 
docs/VISION_VALIDATION.md CHANGED
@@ -1,5 +1,12 @@
  # Vision Validation: Does the Framework Encapsulate the Original Brief?

  > **Status:** Self-audit, 2026-05-25 (Wave 6).
  > **Question:** Does what we've built reflect what was originally asked for, or did we drift?
  > **Method:** Recover original brief verbatim → atomic-clause decomposition → traceability matrix → adversarial self-review → user-journey simulation → concrete pass/fail scorecard with gap-closing actions.
  # Vision Validation: Does the Framework Encapsulate the Original Brief?

+ > **## Status as of 2026-05-29**
+ > The framework is past-skeleton: 8 subpackages (`composer_replication/*`), 210 passing tests, and operational end-to-end examples (`gsm8k_grpo`, `sdpo_with_real_traces_production`). The 3-channel loss, layered hint-generation, trace-ingestion, and DiLoCo have all shipped and been cross-family reviewed.
+ >
+ > **Two remaining honest gaps:**
+ > 1. Docker/TorchForge substrate E2E is hardware-blocked (lacking local multi-GPU rig for the orchestrator layer).
+ > 2. Real LMA full-scale run (8B model, 10k SWE-bench traces) is user-budget-gated.
+
  > **Status:** Self-audit, 2026-05-25 (Wave 6).
  > **Question:** Does what we've built reflect what was originally asked for, or did we drift?
  > **Method:** Recover original brief verbatim → atomic-clause decomposition → traceability matrix → adversarial self-review → user-journey simulation → concrete pass/fail scorecard with gap-closing actions.
examples/altered_minds_channel_ladder/README.md ADDED
@@ -0,0 +1,52 @@
+ # altered_minds_channel_ladder — B4 SDPO-fires proof (ADR-013)
+
+ CPU end-to-end proof that the **SDPO channel actually FIRES (nonzero)** on a
+ batch built by the production `ComposerDataCollator` with **real, collator-built
+ alignment indices**, in the **A2** isolated-channel-ladder config
+ (`alpha_sdpo=0.02`).
+
+ ## Why this exists
+
+ `examples/composer_grpo_sdpo_smoke` proves the SDPO channel is *wired* into a
+ live TRL Dr.GRPO loop, but its toy synthetic rollouts carry no error sites, so
+ `_compute_sdpo_loss` returns `0` — the channel never actually fires. This script
+ closes that gap: it feeds a trace that **has an error turn**, so the collator
+ emits `ctx_teacher_input_ids` + `student_response_idx`/`teacher_response_idx`,
+ and the SDPO JSD is proven **nonzero** on a differentiable grad path.
+
+ ## Proof achieved: `TinyLM-stub-with-differing-tokens`
+
+ - **Alignment indices: REAL** — emitted by the shipped `ComposerDataCollator`
+   from a genuine error-turn trace, exactly as in a real run.
+ - **Model: a deterministic position-dependent `TinyLM` stub** (CPU, no
+   download), the same pattern as
+   `composer_replication/trainer/tests/test_sdpo_alignment_indices.py`.
+ - **Why student tokens are perturbed:** the collator's placeholder-alignment
+   trick makes student and teacher carry identical tokens at identical positions
+   at the valid aligned indices, so a deterministic stub yields `JSD≈0` there
+   (the *correct* answer for a perfectly-aligned identical model). To prove the
+   channel genuinely **gathers** the aligned positions and computes a real
+   divergence, the student's `input_ids` are made to **differ** from the
+   teacher's at exactly those aligned positions — mimicking the hint actually
+   changing the recovery tokens (the real-world case where SDPO has signal to
+   distill). Different aligned tokens ⇒ different logits ⇒ provably **NONZERO**
+   JSD.
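
The perturbation step described in the last bullet can be sketched on plain nested lists. This is an illustration only, assuming list-of-list stand-ins for the batch tensors; the function name `perturb_aligned_tokens` is hypothetical, while the `+3` bump and the modulo by a vocabulary ceiling follow the shipped script.

```python
def perturb_aligned_tokens(input_ids, response_idx, response_valid, vocab_size, bump=3):
    """Return a copy of input_ids with tokens bumped at valid aligned positions."""
    out = [list(row) for row in input_ids]  # leave the caller's batch untouched
    for b, (idx_row, valid_row) in enumerate(zip(response_idx, response_valid)):
        for pos, is_valid in zip(idx_row, valid_row):
            if is_valid:
                # Deterministic shift to a different id that stays inside the vocab.
                out[b][pos] = (out[b][pos] + bump) % vocab_size
    return out


student = [[5, 6, 7, 8, 9]]
idx = [[2, 3, 0]]    # third entry is padding: index 0 with valid flag 0
valid = [[1, 1, 0]]
perturbed = perturb_aligned_tokens(student, idx, valid, vocab_size=17)
```

Only the two valid aligned positions change, so any nonzero divergence measured afterwards comes from the positions the collator indices actually point at.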
+
+ This is the honest, deterministic CPU proof. Loading a real Qwen2.5-0.5B
+ checkpoint is **not required** for the B4 gate and is **not** the same as loading
+ an LMA checkpoint (still user-gated, ADR-013 out-of-scope).
+
+ ## Run
+
+ ```bash
+ cd <repo> && .venv/bin/python examples/altered_minds_channel_ladder/run.py
+ ```
+
+ Optional: `ALTERED_MINDS_REAL_MODEL=1` swaps the stub for a cached
+ Qwen2.5-0.5B-Instruct (offline, much slower on CPU). The same token-perturbation
+ is still required for a nonzero signal.
+
+ Exit `0` = PASS (SDPO fired nonzero), `1` = FAIL, `2` = SKIP (deps unavailable).
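
For callers that want to consume those exit codes programmatically (for example in CI), a small wrapper along these lines may help. It is a sketch, assuming the script is launched from the repository root with the current interpreter; `describe_exit` and `run_proof` are hypothetical helper names, not part of the framework.

```python
import subprocess
import sys

# Exit-code contract documented in the README.
EXIT_STATUS = {0: "PASS", 1: "FAIL", 2: "SKIP"}


def describe_exit(code):
    """Map the script's exit code to its documented status label."""
    return EXIT_STATUS.get(code, f"UNKNOWN({code})")


def run_proof(script="examples/altered_minds_channel_ladder/run.py"):
    """Run the proof script and return its status label (not invoked here)."""
    proc = subprocess.run([sys.executable, script], check=False)
    return describe_exit(proc.returncode)


labels = [describe_exit(code) for code in (0, 1, 2, 3)]
```

Treating `2` as a skip rather than a failure keeps environments without the optional dependencies from reporting a false regression.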
+
+ The automated assertion lives in
+ `composer_replication/integrations/altered_minds/tests/test_channel_ladder.py::test_b4_sdpo_fires_nonzero_with_real_collator_indices`.
examples/altered_minds_channel_ladder/run.py ADDED
@@ -0,0 +1,229 @@
+ """B4 end-to-end CPU proof: the SDPO channel actually FIRES (NONZERO) on a real
+ collator-built batch with genuine alignment indices (ADR-013).
+
+ The existing examples/composer_grpo_sdpo_smoke proves the SDPO channel is *wired*
+ into a live TRL Dr.GRPO loop, but its toy synthetic rollouts carry no error
+ sites, so _compute_sdpo_loss returns 0 (the channel never actually fires). This
+ script closes that gap: it builds a REAL ComposerDataCollator batch from a trace
+ that HAS an error turn — so ctx_teacher_input_ids + student/teacher_response_idx
+ are emitted by the shipped collator — and proves the SDPO JSD is NONZERO over
+ >=1 step, in the A2 ladder config (alpha_sdpo=0.02).
+
+ PROOF ACHIEVED: stub-with-differing-tokens (NOT a real Qwen checkpoint).
+
+ - Alignment indices: REAL (production ComposerDataCollator, real error turn).
+ - Model: a deterministic position-dependent TinyLM stub (CPU, no download),
+   the same pattern used by trainer/tests/test_sdpo_alignment_indices.py.
+ - Why perturb student tokens: the collator's placeholder-alignment trick makes
+   student & teacher carry identical tokens at identical positions at the valid
+   aligned indices, so a deterministic stub yields JSD≈0 there (correct for a
+   perfectly-aligned identical model). To prove the channel GATHERS the aligned
+   positions and computes a real divergence, the student's input_ids are made to
+   DIFFER from the teacher's at exactly those aligned positions — mimicking the
+   hint actually changing the recovery tokens (the real-world case where SDPO
+   has signal to distill). Different aligned tokens => different logits =>
+   provably NONZERO JSD, on a differentiable grad path.
+
+ To run the SAME assertion against a real Qwen2.5-0.5B-Instruct (if cached
+ offline), set ALTERED_MINDS_REAL_MODEL=1 — note that even with a real model the
+ NONZERO signal still requires the aligned student/teacher tokens to differ, so
+ this script keeps the same token-perturbation; the real-model path only swaps
+ the stub for the HF model and is much slower on CPU.
+
+ Exit 0 = PASS (SDPO fired nonzero), 1 = FAIL, 2 = SKIP (deps unavailable).
+ """
+ from __future__ import annotations
+
+ import os
+ import sys
+
+
+ def _build_tiny_lm(vocab: int):
+     import torch
+
+     class _TinyLM(torch.nn.Module):
+         def __init__(self, vocab: int = 64, hidden: int = 8, max_pos: int = 512):
+             super().__init__()
+             torch.manual_seed(0)
+             self.embed = torch.nn.Embedding(vocab, hidden)
+             self.pos = torch.nn.Embedding(max_pos, hidden)
+             self.head = torch.nn.Linear(hidden, vocab)
+
+         def forward(self, input_ids):
+             T = input_ids.size(1)
+             positions = torch.arange(T, device=input_ids.device).unsqueeze(0)
+             h = self.embed(input_ids) + self.pos(positions)
+
+             class _Out:
+                 pass
+
+             out = _Out()
+             out.logits = self.head(h)
+             return out
+
+     return _TinyLM(vocab=max(vocab, 8))
+
+
+ class _StubTok:
+     pad_token_id = 0
+
+     def __init__(self) -> None:
+         self._v = {"<pad>": 0, "<bos>": 1, "<eos>": 2}
+
+     def _id(self, w: str) -> int:
+         if w not in self._v:
+             self._v[w] = len(self._v)
+         return self._v[w]
+
+     def __call__(self, text, **_k):
+         return {"input_ids": [self._id(w) for w in text.split()] if text else []}
+
+     def apply_chat_template(self, messages, tokenize=True, **_k):  # noqa: ARG002
+         return [
+             self._id(w)
+             for w in " ".join(m.get("content", "") for m in messages).split()
+         ]
+
+
+ def _hint_gen(_kind, _meta):
+     return "HINT search before reading"
+
+
+ def _error_trace():
+     return {
+         "trace_id": "b4-channel-ladder",
+         "turns": [
+             {"role": "user", "content": "do the task now"},
+             {"role": "user", "content": "tool not found error occurred"},
+             {
+                 "role": "assistant",
+                 "content": "let me use a real working tool instead now",
+                 "tool_error": "tool_not_found",
+                 "error_meta": {},
+             },
+         ],
+         "final_reward": 0.0,
+     }
+
+
+ def main() -> int:
+     os.environ.setdefault("HF_HUB_OFFLINE", "1")
+     os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
+
+     try:
+         import torch  # noqa: F401
+
+         from composer_replication.integrations.altered_minds import (
+             channel_ladder_configs,
+         )
+         from composer_replication.trainer.composer_trainer import (
+             ComposerReplicationTrainer,
+             make_dr_grpo_config,
+         )
+         from composer_replication.trainer.data_collator import (
+             CollatorConfig,
+             ComposerDataCollator,
+         )
+     except Exception as e:  # noqa: BLE001
+         print(f"SKIP: import failed: {e!r}")
+         return 2
+
+     # A2 arm = +SDPO small (alpha_sdpo=0.02), the amplification probe.
+     a2 = next(a for a in channel_ladder_configs() if a["arm"] == "A2")
+     print(f"[b4] ladder arm A2: alpha_sdpo={a2['alpha_sdpo']} "
+           f"beta_replay={a2['beta_replay']} kl_beta={a2['kl_beta']}")
+
+     # make_dr_grpo_config is exercised to prove the config wiring is intact
+     # (the actual TinyLM stub forward does not need a GRPOConfig, but a real A2
+     # runner would pass this through to ComposerReplicationTrainer).
+     try:
+         cfg = make_dr_grpo_config(output_dir="/tmp/b4_ladder_out", report_to=[])
+         print(f"[b4] Dr.GRPO config OK: loss_type={cfg.loss_type} "
+               f"scale_rewards={cfg.scale_rewards} num_iterations={cfg.num_iterations}")
+     except Exception as e:  # noqa: BLE001
+         print(f"[b4] (config build skipped: {e!r})")
+
+     # --- REAL collator-built batch with a genuine error turn ---
+     tok = _StubTok()
+     collator = ComposerDataCollator(
+         tokenizer=tok,
+         config=CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False),
+     )
+     batch = collator([_error_trace()])
+
+     if batch.get("ctx_teacher_input_ids") is None or batch["ctx_teacher_input_ids"].numel() == 0:
+         print("FAIL: collator emitted no error-site teacher context.")
+         return 1
+     s_idx = batch["student_response_idx"]
+     s_valid = batch["student_response_valid"]
+     if int(s_valid.sum()) == 0:
+         print("FAIL: no valid aligned response positions.")
+         return 1
+     print(f"[b4] collator emitted real alignment indices: "
+           f"student_response_idx shape={tuple(s_idx.shape)}, "
+           f"valid positions={int(s_valid.sum())}")
+
+     # --- Make the student tokens differ from teacher at aligned positions ---
+     student_ids = batch["input_ids"].clone()
+     vocab_ceiling = int(
+         max(batch["input_ids"].max(), batch["ctx_teacher_input_ids"].max())
+     ) + 8
+     for b in range(s_idx.shape[0]):
+         for k in range(s_idx.shape[1]):
+             if bool(s_valid[b, k]):
+                 pos = int(s_idx[b, k])
+                 student_ids[b, pos] = (int(student_ids[b, pos]) + 3) % vocab_ceiling
+     batch["input_ids"] = student_ids
+
+     real_model = os.environ.get("ALTERED_MINDS_REAL_MODEL") == "1"
+     if real_model:
+         try:
+             from transformers import AutoModelForCausalLM
+             model_id = os.environ.get("SMOKE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
+             print(f"[b4] loading real model {model_id} (CPU, slow) ...")
+             model = AutoModelForCausalLM.from_pretrained(model_id)
+             print("[b4] real model loaded; proof path = REAL-MODEL")
+         except Exception as e:  # noqa: BLE001
+             print(f"[b4] real model unavailable ({e!r}); falling back to TinyLM stub")
+             model = _build_tiny_lm(vocab_ceiling)
+             real_model = False
+     else:
+         model = _build_tiny_lm(vocab_ceiling)
+
+     # --- A2 config: SDPO-only small (alpha_sdpo=0.02), strict alignment ---
+     obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
+     obj.alpha_sdpo = float(a2["alpha_sdpo"])
+     obj.sdpo_jsd_beta = 0.5
+     obj.sdpo_temperature = 1.0
+     obj.sdpo_token_clip = None
+     obj.strict_sdpo_alignment = True
+
+     loss = obj._compute_sdpo_loss(model, batch)
+     val = float(loss.detach())
+     print("=" * 64)
+     print(f" proof path: {'REAL-MODEL' if real_model else 'TinyLM-stub-with-differing-tokens'}")
+     print(f" SDPO JSD (sdpo_kl): {val:.6f}")
+     print(f" requires_grad: {loss.requires_grad}")
+
+     if not (val == val) or val in (float("inf"), float("-inf")):
+         print(" RESULT: FAIL ❌ (loss not finite)")
+         return 1
+     if val <= 1e-6:
+         print(" RESULT: FAIL ❌ (SDPO channel did not fire — JSD ~0)")
+         return 1
+
+     (obj.alpha_sdpo * loss).backward()
+     grad_norm = sum(
+         float(p.grad.norm()) for p in model.parameters() if p.grad is not None
+     )
+     print(f" grad norm into model: {grad_norm:.6f}")
+     if grad_norm <= 0.0:
+         print(" RESULT: FAIL ❌ (no gradient flowed from SDPO loss)")
+         return 1
+
+     print(" RESULT: PASS ✅ (SDPO channel FIRED nonzero via real collator indices)")
+     return 0
+
+
+ if __name__ == "__main__":
+     sys.exit(main())