Codeseys's picture
Wave 14: close every Wave 13 review finding + 4 documentation files; Wave 14b: real PRIME-RL parity + multi-process DiLoCo convergence
d9dd3a5
"""DJNormalizer — data-juicer adapter for replaysim DPO output.
Wraps the framework's `extract_dpo_pairs` output in a data-juicer op-graph.
The op-graph runs entirely CPU-side and applies length filtering, chat-
template validation, and per-conversation deduplication. Ops are loaded
from a YAML recipe so users can swap normalization strategies without
touching framework code.
Default recipe lives at:
composer_replication/recipes/replaysim/default.yaml
The data-juicer dependency is optional (pulled by the `[replaysim]` extra).
This file imports it lazily inside method bodies so that the package
imports cleanly without it.
Source-of-truth shape (from `composer_replication.teacher_replay`):
DPOPair = TypedDict("DPOPair", {
"state_id": str,
"state_messages": list[dict], # conversation up to this step
"chosen": str, # teacher-consensus action
"rejected": str, # student action
"n_teachers_agreeing": int,
})
The normalizer does NOT require chosen_teacher / rejected_teacher fields —
those don't exist in the real DPOPair shape.
"""
from __future__ import annotations
import asyncio
import json
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, cast
from composer_replication.teacher_replay import (
DPOPair,
TeacherCallResult,
extract_dpo_pairs,
replay_trace,
)
@dataclass
class NormalizedDPOPair:
    """One DPO preference pair after it has gone through the normalizer.

    Carries the same information as a ``DPOPair``, but both completions are
    expressed as chat-messages lists (the multi-turn shape data-juicer's
    chat-aware ops work with), and a ``metadata`` dict records op-graph
    provenance.
    """

    # Identifier of the trace state (turn) the pair was extracted from.
    state_id: str
    # Conversation context up to (and including) this step's user prompt.
    state_messages: list[dict[str, Any]]
    # Preferred completion, as a chat-messages list (one assistant turn).
    chosen_messages: list[dict[str, Any]]
    # Dispreferred completion, as a chat-messages list (one assistant turn).
    rejected_messages: list[dict[str, Any]]
    # Number of teachers that agreed on the chosen action (copied from DPOPair).
    n_teachers_agreeing: int
    # Op-graph provenance: which ops fired and what they changed.
    metadata: dict[str, Any]
def _dpo_pair_to_dj_record(pair: DPOPair | dict[str, Any]) -> dict[str, Any]:
"""Convert a DPOPair (or dict-shaped equivalent) into a data-juicer record.
The record carries TWO shapes for chosen/rejected so that data-juicer ops
that expect string-typed text fields (e.g. ``text_length_filter``,
``words_num_filter``, ``special_characters_filter``,
``document_deduplicator``) work alongside chat-aware ops:
- ``chosen`` / ``rejected``: flat strings (drives the standard text ops
that read string fields via ``text_keys``).
- ``chosen_messages`` / ``rejected_messages``: chat-messages list
(one assistant turn each), preserving the multi-turn-aware shape.
The ``messages`` field carries the conversation context (matches
data-juicer's ``messages`` convention for chat-aware filters).
"""
p = cast(dict[str, Any], pair)
chosen_str = p.get("chosen", "") or ""
rejected_str = p.get("rejected", "") or ""
return {
"state_id": p.get("state_id", ""),
"messages": p.get("state_messages", []),
# Flat-string shape for length/word/special-char/dedup filters
# that expect text_keys to point at strings.
"chosen": chosen_str,
"rejected": rejected_str,
# Chat-messages shape for chat-aware ops and the NormalizedDPOPair
# round-trip.
"chosen_messages": [{"role": "assistant", "content": chosen_str}],
"rejected_messages": [{"role": "assistant", "content": rejected_str}],
"n_teachers_agreeing": p.get("n_teachers_agreeing", 0),
}
def _dj_record_to_normalized(rec: dict[str, Any]) -> NormalizedDPOPair:
    """Inverse — convert a data-juicer record back to NormalizedDPOPair.

    Reconciles the two shapes a record carries for each completion:

    - If ``chosen_messages``/``rejected_messages`` is a non-empty list, it is
      used, EXCEPT when it is a single assistant turn whose content disagrees
      with a non-empty flat ``chosen``/``rejected`` string. That is the case
      where a data-juicer op rewrote the string field (ops configured via
      ``text_keys`` only see the flat strings) but left the messages field
      untouched. The rewritten string then wins and replaces the turn's
      ``content`` (other keys of the turn are kept).
    - Otherwise, a non-empty flat string is wrapped into a
      single-assistant-turn messages list.
    - A list found in the flat field is passed through as-is.
    - Anything else yields ``[]``.

    Args:
        rec: one JSON record as read back from data-juicer output (or
            produced directly by the passthrough path).

    Returns:
        The reassembled NormalizedDPOPair; ``metadata`` comes from the
        record's ``__dj_meta__`` key (``{}`` when absent).
    """

    def _to_messages(val: Any, flat: Any) -> list[dict[str, Any]]:
        has_flat_text = isinstance(flat, str) and bool(flat)
        if isinstance(val, list) and val:
            turn = val[0]
            is_single_assistant_turn = (
                len(val) == 1
                and isinstance(turn, dict)
                and turn.get("role") == "assistant"
            )
            # Stale-copy detection: the flat string is what text ops rewrite,
            # so when the two disagree on the simple one-turn shape, trust it.
            # NOTE(review): a chat-aware op that edits only the messages copy
            # would lose to a stale flat string here — confirm no recipe
            # rewrites both shapes independently.
            if (
                has_flat_text
                and is_single_assistant_turn
                and turn.get("content") != flat
            ):
                return [{**turn, "content": flat}]
            return val  # already chat-messages shape
        if has_flat_text:
            return [{"role": "assistant", "content": flat}]
        if isinstance(flat, list):
            # Edge case: someone put the messages list in the flat field.
            return flat
        return []

    chosen_messages = _to_messages(
        rec.get("chosen_messages"), rec.get("chosen", "")
    )
    rejected_messages = _to_messages(
        rec.get("rejected_messages"), rec.get("rejected", "")
    )
    return NormalizedDPOPair(
        state_id=rec.get("state_id", ""),
        state_messages=rec.get("messages", []),
        chosen_messages=chosen_messages,
        rejected_messages=rejected_messages,
        n_teachers_agreeing=rec.get("n_teachers_agreeing", 0),
        # NOTE(review): only the ``__dj_meta__`` key is read; confirm whether
        # data-juicer's own stats/meta columns should also be surfaced here.
        metadata=rec.get("__dj_meta__", {}),
    )
class DJNormalizer:
    """data-juicer-backed normalizer for DPO pairs.

    Args:
        recipe_path: path to a data-juicer YAML recipe. If None, uses
            ``DEFAULT_RECIPE``, the framework's default recipe (length filter
            + chat-template validation + per-conversation dedup).
        skip_dj: if True, the normalizer becomes a passthrough — useful
            for test environments without data-juicer installed. Records
            are still converted to NormalizedDPOPair shape but no ops run.

    Raises:
        RuntimeError: ``skip_dj`` is False and ``data_juicer`` cannot be
            imported (the ImportError is chained as ``__cause__``).
        FileNotFoundError: ``skip_dj`` is False and the recipe file does
            not exist.
    """

    # Resolved relative to this module's file: two directories up, then
    # recipes/replaysim/default.yaml.
    DEFAULT_RECIPE = (
        Path(__file__).parent.parent / "recipes" / "replaysim" / "default.yaml"
    )

    def __init__(
        self,
        recipe_path: str | os.PathLike[str] | None = None,
        *,
        skip_dj: bool = False,
    ) -> None:
        self.recipe_path = (
            Path(recipe_path) if recipe_path is not None else self.DEFAULT_RECIPE
        )
        self.skip_dj = skip_dj
        if skip_dj:
            # Passthrough mode needs neither data-juicer nor a recipe file.
            return
        try:
            import data_juicer  # type: ignore[import-not-found] # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                "DJNormalizer requires data-juicer. Install with "
                "`pip install -e .[replaysim]` or pass skip_dj=True "
                "for a passthrough. Got: " + repr(e)
            ) from e
        if not self.recipe_path.exists():
            raise FileNotFoundError(
                f"Recipe not found: {self.recipe_path}. Either pass an "
                f"explicit recipe_path or add the default recipe at this "
                f"location."
            )

    def normalize(
        self,
        pairs: Iterable[DPOPair | dict[str, Any]],
    ) -> list[NormalizedDPOPair]:
        """Run the full normalization op-graph on a batch of DPO pairs.

        In passthrough mode every record gets ``{"skipped": True}`` as its
        metadata and no ops run. Otherwise the records are written as UTF-8
        JSONL into a temporary directory, processed by data-juicer's
        executor with ``recipe_path`` as config, and read back; the
        directory is deleted afterwards.

        Args:
            pairs: iterable of DPOPair (output of extract_dpo_pairs) or
                dict-shaped equivalents. Consumed exactly once.

        Returns:
            list of NormalizedDPOPair, possibly shorter than input (filter
            ops can drop records). An empty input returns ``[]`` without
            invoking data-juicer.
        """
        records = [_dpo_pair_to_dj_record(p) for p in pairs]
        if self.skip_dj:
            for rec in records:
                rec["__dj_meta__"] = {"skipped": True}
            return [_dj_record_to_normalized(r) for r in records]
        if not records:
            # Nothing to filter: skip launching a data-juicer run over an
            # empty dataset.
            return []
        # Real path: write to temp JSONL, hand to data-juicer's Executor,
        # read back. data-juicer's CLI contract is file-in / file-out.
        from data_juicer.config import init_configs  # type: ignore[import-not-found]
        from data_juicer.core import DefaultExecutor  # type: ignore[import-not-found]

        with tempfile.TemporaryDirectory() as td:
            input_path = Path(td) / "input.jsonl"
            output_path = Path(td) / "output.jsonl"
            # Explicit UTF-8 on both sides: the platform's locale default
            # encoding may not be able to round-trip non-ASCII text that the
            # exporter writes.
            with input_path.open("w", encoding="utf-8") as f:
                f.writelines(json.dumps(rec) + "\n" for rec in records)
            cfg = init_configs(
                args=[
                    "--config", str(self.recipe_path),
                    "--dataset_path", str(input_path),
                    "--export_path", str(output_path),
                ],
            )
            DefaultExecutor(cfg).run()
            with output_path.open(encoding="utf-8") as f:
                # Skip blank lines; json.loads tolerates surrounding whitespace.
                output_records: list[dict[str, Any]] = [
                    json.loads(line) for line in f if line.strip()
                ]
        return [_dj_record_to_normalized(r) for r in output_records]
# ---------------------------------------------------------------------
# Convenience: replay + extract pairs + normalize, end to end.
# ---------------------------------------------------------------------
async def replay_and_normalize_trace(
    *,
    states: Any,
    teachers: Any = None,
    agreement_threshold: int = 2,
    max_total_usd: float = 5.0,
    normalizer: DJNormalizer | None = None,
    **replay_kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]:
    """Replay a trace, extract DPO pairs, and normalize them — all in one await.

    Async because the underlying `replay_trace` is async; synchronous callers
    should use `replay_and_normalize_trace_sync` instead.

    Args:
        states: sequence of TraceState (the frozen agentic trace).
        teachers: sequence of TeacherSpec; when None it is not forwarded,
            so `replay_trace` applies its own default.
        agreement_threshold: forwarded to `extract_dpo_pairs`.
        max_total_usd: forwarded to `replay_trace`.
        normalizer: normalizer to apply; a fresh `DJNormalizer()` when None.
            Pass `DJNormalizer(skip_dj=True)` to bypass data-juicer.
        **replay_kwargs: extra keyword arguments forwarded to `replay_trace`.

    Returns:
        ``(teacher_actions, normalized_pairs)`` — the raw teacher results from
        `replay_trace` and the normalized DPO pairs.
    """
    # Constructed before the replay so a missing data-juicer install raises
    # before any teacher replay happens.
    active_normalizer = DJNormalizer() if normalizer is None else normalizer

    trace_kwargs: dict[str, Any] = dict(replay_kwargs)
    if teachers is not None:
        trace_kwargs["teachers"] = teachers
    teacher_actions = await replay_trace(
        states=states,
        max_total_usd=max_total_usd,
        **trace_kwargs,
    )

    # extract_dpo_pairs reads student_action from each state's
    # `student_action` field, so it is not passed separately.
    dpo_pairs = extract_dpo_pairs(
        states=states,
        teacher_actions=teacher_actions,
        agreement_threshold=agreement_threshold,
    )
    return teacher_actions, active_normalizer.normalize(dpo_pairs)
def replay_and_normalize_trace_sync(
    *args: Any,
    **kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]:
    """Blocking counterpart of `replay_and_normalize_trace`.

    All arguments are passed straight through; since the async function's
    parameters are keyword-only, callers should pass keywords. The coroutine
    is driven to completion on a new event loop via ``asyncio.run``, so this
    cannot be called from a thread that already has a running event loop.
    Handy for scripts and tests.
    """
    pipeline = replay_and_normalize_trace(*args, **kwargs)
    return asyncio.run(pipeline)
# Public API of this module: the normalizer, its output record type, and the
# async / sync end-to-end convenience wrappers. The `_dpo_pair_to_dj_record`
# and `_dj_record_to_normalized` converters stay private.
__all__ = [
    "DJNormalizer",
    "NormalizedDPOPair",
    "replay_and_normalize_trace",
    "replay_and_normalize_trace_sync",
]