Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """DJNormalizer — data-juicer adapter for replaysim DPO output. | |
| Wraps the framework's `extract_dpo_pairs` output in a data-juicer op-graph. | |
| The op-graph runs entirely CPU-side and applies length filtering, chat- | |
| template validation, and per-conversation deduplication. Ops are loaded | |
| from a YAML recipe so users can swap normalization strategies without | |
| touching framework code. | |
| Default recipe lives at: | |
| composer_replication/recipes/replaysim/default.yaml | |
| The data-juicer dependency is optional (pulled by the `[replaysim]` extra). | |
| This file imports it lazily inside method bodies so that the package | |
| imports cleanly without it. | |
| Source-of-truth shape (from `composer_replication.teacher_replay`): | |
| DPOPair = TypedDict("DPOPair", { | |
| "state_id": str, | |
| "state_messages": list[dict], # conversation up to this step | |
| "chosen": str, # teacher-consensus action | |
| "rejected": str, # student action | |
| "n_teachers_agreeing": int, | |
| }) | |
| The normalizer does NOT require chosen_teacher / rejected_teacher fields — | |
| those don't exist in the real DPOPair shape. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import os | |
| import tempfile | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Iterable, cast | |
| from composer_replication.teacher_replay import ( | |
| DPOPair, | |
| TeacherCallResult, | |
| extract_dpo_pairs, | |
| replay_trace, | |
| ) | |
@dataclass
class NormalizedDPOPair:
    """A DPOPair that has passed through normalization.

    Carries the same data as DPOPair, but with chosen/rejected reshaped into
    chat-messages lists (matching data-juicer's native multi-turn op support),
    plus a metadata dict tracking which ops fired.

    Decorated with ``@dataclass`` so it can be built with keyword arguments
    (``NormalizedDPOPair(state_id=..., ...)``) and gets value equality and a
    readable ``repr``.
    """

    #: Identifier for the trace state (turn) this pair came from.
    state_id: str
    #: The conversation context up to (and including) this step's user prompt.
    state_messages: list[dict[str, Any]]
    #: The chosen completion as a chat-messages list (one assistant turn).
    chosen_messages: list[dict[str, Any]]
    #: The rejected completion as a chat-messages list (one assistant turn).
    rejected_messages: list[dict[str, Any]]
    #: How many teachers agreed on the chosen action (preserved from DPOPair).
    n_teachers_agreeing: int
    #: Op-graph provenance: which ops fired, what they changed.
    metadata: dict[str, Any]
def _dpo_pair_to_dj_record(pair: DPOPair | dict[str, Any]) -> dict[str, Any]:
    """Convert a DPOPair (or dict-shaped equivalent) into a data-juicer record.

    The record carries TWO shapes for chosen/rejected so that data-juicer ops
    that expect string-typed text fields (e.g. ``text_length_filter``,
    ``words_num_filter``, ``special_characters_filter``,
    ``document_deduplicator``) work alongside chat-aware ops:

    - ``chosen`` / ``rejected``: flat strings (drives the standard text ops
      that read string fields via ``text_keys``).
    - ``chosen_messages`` / ``rejected_messages``: chat-messages list
      (one assistant turn each), preserving the multi-turn-aware shape.

    The ``messages`` field carries the conversation context (matches
    data-juicer's ``messages`` convention for chat-aware filters).

    Missing keys and explicit ``None`` values both fall back to empty
    defaults (``""`` for ``state_id``, ``[]`` for ``messages``, ``0`` for
    ``n_teachers_agreeing``). Every record therefore has the same field types
    whatever the input quality; ``chosen``/``rejected`` additionally fold any
    falsy value to ``""``. A list-valued conversation is shallow-copied so
    the record (and anything built from it) does not alias the caller's list.

    Args:
        pair: a DPOPair or a plain dict with the same keys.

    Returns:
        A JSON-serializable dict with keys ``state_id``, ``messages``,
        ``chosen``, ``rejected``, ``chosen_messages``, ``rejected_messages``
        and ``n_teachers_agreeing``.
    """
    p = cast(dict[str, Any], pair)

    def _field(key: str, default: Any) -> Any:
        # Treat an explicit None exactly like a missing key.
        value = p.get(key)
        return default if value is None else value

    chosen_str = p.get("chosen", "") or ""
    rejected_str = p.get("rejected", "") or ""
    messages = _field("state_messages", [])
    if isinstance(messages, list):
        messages = list(messages)  # shallow copy: don't alias the input

    return {
        "state_id": _field("state_id", ""),
        "messages": messages,
        # Flat-string shape for length/word/special-char/dedup filters
        # that expect text_keys to point at strings.
        "chosen": chosen_str,
        "rejected": rejected_str,
        # Chat-messages shape for chat-aware ops and the NormalizedDPOPair
        # round-trip.
        "chosen_messages": [{"role": "assistant", "content": chosen_str}],
        "rejected_messages": [{"role": "assistant", "content": rejected_str}],
        "n_teachers_agreeing": _field("n_teachers_agreeing", 0),
    }
def _dj_record_to_normalized(rec: dict[str, Any]) -> NormalizedDPOPair:
    """Inverse of ``_dpo_pair_to_dj_record``: rebuild a NormalizedDPOPair.

    ``_dpo_pair_to_dj_record`` writes each side twice. The first copy is a
    flat string (``chosen`` / ``rejected``); the second is a one-turn
    messages list (``chosen_messages`` / ``rejected_messages``). Ops that
    read fields via ``text_keys`` can rewrite the flat string without
    touching the messages list, so each side is resolved as follows:

    - If the messages list is exactly the single ``{"role": "assistant",
      "content": ...}`` wrapper and its content differs from a present flat
      string, the string was rewritten after conversion. The rewritten
      string is re-wrapped and wins.
    - Any other non-empty messages list is used as-is.
    - If the messages list is missing or empty, a non-empty flat string is
      wrapped into one assistant turn. A list found in the flat field is
      passed through. Anything else yields ``[]``.

    Args:
        rec: one record read back from data-juicer's output (or from the
            skip_dj passthrough).

    Returns:
        The reconstructed NormalizedDPOPair. ``metadata`` comes from the
        record's ``__dj_meta__`` key, ``{}`` when absent.
    """

    def _to_messages(val: Any, flat: Any) -> list[dict[str, Any]]:
        if isinstance(val, list) and val:
            head = val[0]
            is_own_wrapper = (
                len(val) == 1
                and isinstance(head, dict)
                and set(head) == {"role", "content"}
                and head["role"] == "assistant"
                and isinstance(head["content"], str)
            )
            if is_own_wrapper and isinstance(flat, str) and head["content"] != flat:
                # Both shapes were identical when the record was built, so
                # divergence means a string-level op rewrote the flat field.
                return [{"role": "assistant", "content": flat}]
            return val  # already chat-messages shape
        if isinstance(flat, str) and flat:
            return [{"role": "assistant", "content": flat}]
        if isinstance(flat, list):
            # Edge case: someone put the messages list in the flat field.
            return flat
        return []

    # Flat fields are read without a default: a *missing* flat field (None)
    # must never override an existing messages list.
    chosen_messages = _to_messages(rec.get("chosen_messages"), rec.get("chosen"))
    rejected_messages = _to_messages(
        rec.get("rejected_messages"), rec.get("rejected")
    )
    return NormalizedDPOPair(
        state_id=rec.get("state_id", ""),
        state_messages=rec.get("messages", []),
        chosen_messages=chosen_messages,
        rejected_messages=rejected_messages,
        n_teachers_agreeing=rec.get("n_teachers_agreeing", 0),
        # NOTE(review): only the "__dj_meta__" key (set by the skip_dj
        # passthrough) is read here. Confirm whether the installed
        # data-juicer version writes its own meta/stats under a different
        # key that should be surfaced too.
        metadata=rec.get("__dj_meta__", {}),
    )
class DJNormalizer:
    """data-juicer-backed normalizer for DPO pairs.

    Args:
        recipe_path: path to a data-juicer YAML recipe. If None, uses the
            framework's default recipe (``DEFAULT_RECIPE``: length filter +
            chat-template validation + per-conversation dedup).
        skip_dj: if True, the normalizer becomes a passthrough — useful
            for test environments without data-juicer installed. Records
            are still converted to NormalizedDPOPair shape (with
            ``metadata == {"skipped": True}``) but no ops run.

    Raises:
        RuntimeError: on construction, when ``skip_dj`` is False and
            data-juicer cannot be imported.
        FileNotFoundError: on construction, when ``skip_dj`` is False and
            the recipe file does not exist.
    """

    #: Recipe used when no ``recipe_path`` is given, resolved relative to
    #: this file's grandparent directory.
    DEFAULT_RECIPE = (
        Path(__file__).parent.parent / "recipes" / "replaysim" / "default.yaml"
    )

    def __init__(
        self,
        recipe_path: str | os.PathLike[str] | None = None,
        *,
        skip_dj: bool = False,
    ) -> None:
        self.recipe_path = (
            Path(recipe_path) if recipe_path is not None else self.DEFAULT_RECIPE
        )
        self.skip_dj = skip_dj
        if skip_dj:
            # Passthrough mode needs neither data-juicer nor the recipe.
            return
        try:
            import data_juicer  # type: ignore[import-not-found] # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                "DJNormalizer requires data-juicer. Install with "
                "`pip install -e .[replaysim]` or pass skip_dj=True "
                "for a passthrough. Got: " + repr(e)
            ) from e
        if not self.recipe_path.exists():
            raise FileNotFoundError(
                f"Recipe not found: {self.recipe_path}. Either pass an "
                f"explicit recipe_path or add the default recipe at this "
                f"location."
            )

    def normalize(
        self,
        pairs: Iterable[DPOPair | dict[str, Any]],
    ) -> list[NormalizedDPOPair]:
        """Run the full normalization op-graph on a batch of DPO pairs.

        Args:
            pairs: iterable of DPOPair (output of extract_dpo_pairs) or
                dict-shaped equivalents. Consumed exactly once.

        Returns:
            list of NormalizedDPOPair, possibly shorter than input (filter
            ops can drop records). Empty input returns ``[]`` without
            invoking data-juicer.
        """
        records = [_dpo_pair_to_dj_record(p) for p in pairs]

        if self.skip_dj:
            for rec in records:
                rec["__dj_meta__"] = {"skipped": True}
            return [_dj_record_to_normalized(r) for r in records]

        if not records:
            # Nothing to filter: don't spin up an Executor over an empty file.
            return []

        # Real path: write to temp JSONL, hand to data-juicer's Executor,
        # read back. data-juicer's CLI contract is file-in / file-out.
        from data_juicer.config import init_configs  # type: ignore[import-not-found]
        from data_juicer.core import DefaultExecutor  # type: ignore[import-not-found]

        with tempfile.TemporaryDirectory() as td:
            input_path = Path(td) / "input.jsonl"
            output_path = Path(td) / "output.jsonl"
            # JSON Lines is UTF-8; pin the encoding so the round-trip does
            # not depend on the platform's locale default (e.g. cp1252 on
            # Windows), which can fail on non-ASCII text in the export.
            with input_path.open("w", encoding="utf-8") as f:
                f.writelines(json.dumps(rec) + "\n" for rec in records)
            cfg = init_configs(
                args=[
                    "--config", str(self.recipe_path),
                    "--dataset_path", str(input_path),
                    "--export_path", str(output_path),
                ],
            )
            DefaultExecutor(cfg).run()
            with output_path.open(encoding="utf-8") as f:
                # Skip blank lines; json.loads tolerates surrounding whitespace.
                output_records: list[dict[str, Any]] = [
                    json.loads(line) for line in f if line.strip()
                ]
        return [_dj_record_to_normalized(r) for r in output_records]
# ---------------------------------------------------------------------
# Convenience: replay + extract pairs + normalize, end to end.
# ---------------------------------------------------------------------
async def replay_and_normalize_trace(
    *,
    states: Any,
    teachers: Any = None,
    agreement_threshold: int = 2,
    max_total_usd: float = 5.0,
    normalizer: DJNormalizer | None = None,
    **replay_kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]:
    """Async convenience: replay → extract pairs → normalize, in one call.

    The underlying `replay_trace` is async; this wrapper preserves that
    so callers can `await` it from an async context. For sync callers
    use `replay_and_normalize_trace_sync`.

    Args:
        states: sequence of TraceState (the frozen agentic trace). It is
            read twice (replay, then pair extraction), so a one-shot
            iterator/generator is materialized into a list first.
        teachers: sequence of TeacherSpec. When None it is not forwarded,
            so `replay_trace`'s own default applies.
        agreement_threshold: passed to `extract_dpo_pairs`
        max_total_usd: passed to `replay_trace`
        normalizer: defaults to `DJNormalizer()`. Pass
            `DJNormalizer(skip_dj=True)` to bypass data-juicer.
        **replay_kwargs: extra kwargs forwarded to `replay_trace`.

    Returns:
        Tuple of (raw teacher_actions, normalized DPO pairs).

    Raises:
        RuntimeError / FileNotFoundError: from the default `DJNormalizer()`
            when data-juicer or its recipe is missing. This is raised
            before any teacher calls are made.
    """
    from collections.abc import Iterator

    # Construct the default normalizer up front so a missing data-juicer
    # install fails before any of the max_total_usd budget is spent.
    if normalizer is None:
        normalizer = DJNormalizer()

    # A generator would be exhausted by replay_trace, leaving
    # extract_dpo_pairs with nothing and silently producing zero pairs.
    if isinstance(states, Iterator):
        states = list(states)

    if teachers is not None:
        replay_kwargs["teachers"] = teachers
    teacher_actions = await replay_trace(
        states=states,
        max_total_usd=max_total_usd,
        **replay_kwargs,
    )

    # extract_dpo_pairs reads student_action from each state's
    # `student_action` field, so we don't need to pass it separately.
    raw_pairs = extract_dpo_pairs(
        states=states,
        teacher_actions=teacher_actions,
        agreement_threshold=agreement_threshold,
    )
    # NOTE: normalize() is synchronous; with data-juicer enabled it runs the
    # executor inline and blocks this event loop until it finishes.
    normalized = normalizer.normalize(raw_pairs)
    return teacher_actions, normalized
def replay_and_normalize_trace_sync(
    *args: Any,
    **kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]:
    """Blocking front door to `replay_and_normalize_trace`.

    Builds the coroutine from the given arguments and drives it to
    completion on a fresh event loop via ``asyncio.run``. This is handy for
    scripts and tests that have no loop of their own. Per the asyncio docs,
    ``asyncio.run`` cannot be used while another event loop is running in
    the same thread; async callers should ``await
    replay_and_normalize_trace(...)`` directly.
    """
    pipeline = replay_and_normalize_trace(*args, **kwargs)
    return asyncio.run(pipeline)
# Public surface of this module: the normalizer, its output record type,
# and the async/sync end-to-end pipeline helpers.
__all__ = ["DJNormalizer", "NormalizedDPOPair",
           "replay_and_normalize_trace", "replay_and_normalize_trace_sync"]