Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 11,683 Bytes
"""DJNormalizer — data-juicer adapter for replaysim DPO output.
Wraps the framework's `extract_dpo_pairs` output in a data-juicer op-graph.
The op-graph runs entirely CPU-side and applies length filtering, chat-
template validation, and per-conversation deduplication. Ops are loaded
from a YAML recipe so users can swap normalization strategies without
touching framework code.
Default recipe lives at:
composer_replication/recipes/replaysim/default.yaml
The data-juicer dependency is optional (pulled by the `[replaysim]` extra).
This file imports it lazily inside method bodies so that the package
imports cleanly without it.
Source-of-truth shape (from `composer_replication.teacher_replay`):
DPOPair = TypedDict("DPOPair", {
"state_id": str,
"state_messages": list[dict], # conversation up to this step
"chosen": str, # teacher-consensus action
"rejected": str, # student action
"n_teachers_agreeing": int,
})
The normalizer does NOT require chosen_teacher / rejected_teacher fields —
those don't exist in the real DPOPair shape.
"""
from __future__ import annotations
import asyncio
import json
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, cast
from composer_replication.teacher_replay import (
DPOPair,
TeacherCallResult,
extract_dpo_pairs,
replay_trace,
)
@dataclass
class NormalizedDPOPair:
    """One DPO pair after the normalization op-graph has run.

    Holds the same information as ``DPOPair``, but the chosen and rejected
    completions are stored as chat-messages lists (the multi-turn shape
    data-juicer's chat ops work with), and a ``metadata`` dict records
    op-graph provenance.
    """

    # Identifier of the trace state (turn) the pair was extracted from.
    state_id: str
    # Conversation context up to and including this step's user prompt.
    state_messages: list[dict[str, Any]]
    # Chosen completion as a chat-messages list (one assistant turn).
    chosen_messages: list[dict[str, Any]]
    # Rejected completion as a chat-messages list (one assistant turn).
    rejected_messages: list[dict[str, Any]]
    # Number of teachers that agreed on the chosen action (kept from DPOPair).
    n_teachers_agreeing: int
    # Op-graph provenance: which ops fired and what they changed.
    metadata: dict[str, Any]
def _dpo_pair_to_dj_record(pair: DPOPair | dict[str, Any]) -> dict[str, Any]:
    """Flatten one DPO pair into the JSON-able record data-juicer consumes.

    Each completion is emitted in two representations side by side:

    * ``chosen`` / ``rejected`` — plain strings, for string-typed text ops
      (e.g. ``text_length_filter``, ``words_num_filter``,
      ``special_characters_filter``, ``document_deduplicator``) that read
      the fields named in ``text_keys``.
    * ``chosen_messages`` / ``rejected_messages`` — the same text wrapped as
      a one-element list holding a single assistant turn, for chat-aware
      ops and for the round-trip back to ``NormalizedDPOPair``.

    The conversation context is stored under ``messages``, data-juicer's
    field name for chat-aware filters.

    Args:
        pair: a ``DPOPair`` or any dict with the same keys. Missing keys
            fall back to ``""`` (``state_id``, ``chosen``, ``rejected``),
            ``[]`` (``state_messages``) or ``0`` (``n_teachers_agreeing``).
            A falsy ``chosen``/``rejected`` value (e.g. ``None``) also
            becomes ``""``.

    Returns:
        A new dict. Its ``messages`` entry is the pair's own
        ``state_messages`` object, not a copy.
    """
    source = cast(dict[str, Any], pair)
    # `or ""` collapses None / other falsy values to the empty string.
    chosen_text, rejected_text = (
        source.get(key, "") or "" for key in ("chosen", "rejected")
    )

    def as_assistant_turn(text: Any) -> list[dict[str, Any]]:
        return [{"role": "assistant", "content": text}]

    return dict(
        state_id=source.get("state_id", ""),
        messages=source.get("state_messages", []),
        # Flat strings for the text_keys-driven length/word/char/dedup ops.
        chosen=chosen_text,
        rejected=rejected_text,
        # Chat-shaped copies for chat-aware ops and the round-trip.
        chosen_messages=as_assistant_turn(chosen_text),
        rejected_messages=as_assistant_turn(rejected_text),
        n_teachers_agreeing=source.get("n_teachers_agreeing", 0),
    )
def _dj_record_to_normalized(rec: dict[str, Any]) -> NormalizedDPOPair:
    """Inverse — convert a data-juicer record back to NormalizedDPOPair.

    Records built by ``_dpo_pair_to_dj_record`` carry each completion twice:
    as a flat string (``chosen`` / ``rejected``) and as a derived
    single-assistant-turn list (``chosen_messages`` / ``rejected_messages``)
    with the same text. A data-juicer op addressing the flat field via
    ``text_keys`` can rewrite it without touching the list, so the two
    shapes may disagree when the record comes back. Resolution per side:

    - The messages list is exactly one assistant turn whose content differs
      from a non-empty flat string: the flat string was rewritten and the
      list is stale, so the flat string is used.
    - Otherwise a non-empty messages list is used as-is. This covers
      multi-turn lists and the normal case where both shapes agree.
    - The messages list is missing or empty: a non-empty flat string is
      wrapped into one assistant turn.
    - A list stored in the flat field is passed through unchanged.
    - Anything else yields ``[]``.

    Args:
        rec: one record, either read back from data-juicer's JSONL export
            or produced directly by the skip_dj passthrough.

    Returns:
        NormalizedDPOPair. ``state_messages`` defaults to ``[]`` and
        ``metadata`` to ``{}`` when the corresponding field is missing
        or null.
    """

    def _wrap(text: str) -> list[dict[str, Any]]:
        return [{"role": "assistant", "content": text}]

    def _to_messages(val: Any, flat: Any) -> list[dict[str, Any]]:
        has_flat_text = isinstance(flat, str) and bool(flat)
        if isinstance(val, list) and val:
            if (
                has_flat_text
                and len(val) == 1
                and isinstance(val[0], dict)
                and val[0].get("role") == "assistant"
                and val[0].get("content") != flat
            ):
                # The single turn is the copy derived from the original
                # flat string; an op rewrote the string afterwards.
                return _wrap(flat)
            return val  # already chat-messages shape and still current
        if has_flat_text:
            return _wrap(flat)
        if isinstance(flat, list):
            # Edge case: someone put the messages list in the flat field.
            return flat
        return []

    chosen_messages = _to_messages(
        rec.get("chosen_messages"), rec.get("chosen", "")
    )
    rejected_messages = _to_messages(
        rec.get("rejected_messages"), rec.get("rejected", "")
    )
    return NormalizedDPOPair(
        state_id=rec.get("state_id", ""),
        state_messages=rec.get("messages") or [],
        chosen_messages=chosen_messages,
        rejected_messages=rejected_messages,
        n_teachers_agreeing=rec.get("n_teachers_agreeing", 0),
        # NOTE(review): within this file only the skip_dj passthrough writes
        # "__dj_meta__"; confirm whether data-juicer's own stats/meta
        # columns should also be surfaced here on the real path.
        metadata=rec.get("__dj_meta__") or {},
    )
class DJNormalizer:
    """data-juicer-backed normalizer for DPO pairs.

    Args:
        recipe_path: path to a data-juicer YAML recipe. If None, uses the
            framework's default recipe (length filter + chat-template
            validation + per-conversation dedup).
        skip_dj: if True, the normalizer becomes a passthrough — useful
            for test environments without data-juicer installed. Records
            are still converted to NormalizedDPOPair shape but no ops run.

    Raises:
        RuntimeError: at construction, when ``skip_dj`` is False and
            ``data_juicer`` cannot be imported. The original ImportError is
            chained as ``__cause__``.
        FileNotFoundError: at construction, when ``skip_dj`` is False and
            the recipe file does not exist.
    """

    # Resolved relative to this module file: two directory levels up
    # (parent.parent), then recipes/replaysim/default.yaml.
    DEFAULT_RECIPE = (
        Path(__file__).parent.parent / "recipes" / "replaysim" / "default.yaml"
    )

    def __init__(
        self,
        recipe_path: str | os.PathLike[str] | None = None,
        *,
        skip_dj: bool = False,
    ) -> None:
        self.recipe_path = (
            Path(recipe_path) if recipe_path is not None else self.DEFAULT_RECIPE
        )
        self.skip_dj = skip_dj
        if skip_dj:
            # Passthrough mode needs neither data-juicer nor a recipe file.
            return
        try:
            import data_juicer  # type: ignore[import-not-found] # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                "DJNormalizer requires data-juicer. Install with "
                "`pip install -e .[replaysim]` or pass skip_dj=True "
                "for a passthrough. Got: " + repr(e)
            ) from e
        if not self.recipe_path.exists():
            raise FileNotFoundError(
                f"Recipe not found: {self.recipe_path}. Either pass an "
                f"explicit recipe_path or add the default recipe at this "
                f"location."
            )

    def normalize(
        self,
        pairs: Iterable[DPOPair | dict[str, Any]],
    ) -> list[NormalizedDPOPair]:
        """Run the full normalization op-graph on a batch of DPO pairs.

        Args:
            pairs: iterable of DPOPair (output of extract_dpo_pairs) or
                dict-shaped equivalents. Iterated exactly once.

        Returns:
            list of NormalizedDPOPair, possibly shorter than input (filter
            ops can drop records). An empty input returns ``[]`` without
            invoking data-juicer.

        Errors raised by data-juicer's config parsing or executor propagate
        unchanged.
        """
        records = [_dpo_pair_to_dj_record(p) for p in pairs]
        if not records:
            # Nothing to normalize; don't hand data-juicer an empty dataset.
            return []

        if self.skip_dj:
            for rec in records:
                rec["__dj_meta__"] = {"skipped": True}
            return [_dj_record_to_normalized(r) for r in records]

        # Real path: write to temp JSONL, hand to data-juicer's Executor,
        # read back. data-juicer's CLI contract is file-in / file-out.
        from data_juicer.config import init_configs  # type: ignore[import-not-found]
        from data_juicer.core import DefaultExecutor  # type: ignore[import-not-found]

        with tempfile.TemporaryDirectory() as td:
            input_path = Path(td) / "input.jsonl"
            output_path = Path(td) / "output.jsonl"
            # Explicit UTF-8 on both ends. Text mode otherwise uses the
            # platform's default encoding, and nothing here guarantees the
            # exported JSONL is ASCII-only.
            with input_path.open("w", encoding="utf-8") as f:
                f.writelines(json.dumps(rec) + "\n" for rec in records)
            cfg = init_configs(
                args=[
                    "--config", str(self.recipe_path),
                    "--dataset_path", str(input_path),
                    "--export_path", str(output_path),
                ],
            )
            DefaultExecutor(cfg).run()
            with output_path.open(encoding="utf-8") as f:
                # Skip blank lines; json.loads tolerates the trailing newline.
                output_records: list[dict[str, Any]] = [
                    json.loads(line) for line in f if line.strip()
                ]
        return [_dj_record_to_normalized(r) for r in output_records]
# ---------------------------------------------------------------------
# Convenience: replay + extract pairs + normalize, end to end.
# ---------------------------------------------------------------------
async def replay_and_normalize_trace(
    *,
    states: Any,
    teachers: Any = None,
    agreement_threshold: int = 2,
    max_total_usd: float = 5.0,
    normalizer: DJNormalizer | None = None,
    **replay_kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]:
    """Replay a trace, extract DPO pairs and normalize them in one await.

    `replay_trace` is async, so this helper is async as well. Synchronous
    callers should use `replay_and_normalize_trace_sync`.

    Args:
        states: sequence of TraceState (the frozen agentic trace).
        teachers: sequence of TeacherSpec. When None, the argument is not
            forwarded, so `replay_trace` applies its own default.
        agreement_threshold: forwarded to `extract_dpo_pairs`.
        max_total_usd: forwarded to `replay_trace`.
        normalizer: defaults to `DJNormalizer()`. Pass
            `DJNormalizer(skip_dj=True)` to bypass data-juicer.
        **replay_kwargs: any further keyword arguments for `replay_trace`.

    Returns:
        A tuple of (raw teacher_actions, normalized DPO pairs).
    """
    # Construct the default normalizer before replaying: if data-juicer is
    # missing, DJNormalizer() raises here, before any teacher calls happen.
    active_normalizer = DJNormalizer() if normalizer is None else normalizer

    # replay_kwargs cannot contain states/teachers/max_total_usd because those
    # bind to this function's named parameters, so merging is collision-free.
    replay_args: dict[str, Any] = {"states": states}
    if teachers is not None:
        replay_args["teachers"] = teachers
    replay_args["max_total_usd"] = max_total_usd
    replay_args.update(replay_kwargs)
    teacher_actions = await replay_trace(**replay_args)

    # The student action is read from each state by extract_dpo_pairs
    # itself, so it is not passed here.
    dpo_pairs = extract_dpo_pairs(
        states=states,
        teacher_actions=teacher_actions,
        agreement_threshold=agreement_threshold,
    )
    return teacher_actions, active_normalizer.normalize(dpo_pairs)
def replay_and_normalize_trace_sync(
    *args: Any,
    **kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]:
    """Blocking front-end to `replay_and_normalize_trace`.

    Accepts exactly the arguments of the async version and drives it to
    completion with `asyncio.run`. It is intended for scripts and tests,
    and cannot be used from inside an already-running event loop
    (`asyncio.run` raises RuntimeError there).
    """
    pending = replay_and_normalize_trace(*args, **kwargs)
    return asyncio.run(pending)
# Public API of this module: the names exported by `from <module> import *`.
# The underscore-prefixed record-conversion helpers are deliberately omitted.
__all__ = [
    "DJNormalizer",
    "NormalizedDPOPair",
    "replay_and_normalize_trace",
    "replay_and_normalize_trace_sync",
]
|