"""DJNormalizer — data-juicer adapter for replaysim DPO output. Wraps the framework's `extract_dpo_pairs` output in a data-juicer op-graph. The op-graph runs entirely CPU-side and applies length filtering, chat- template validation, and per-conversation deduplication. Ops are loaded from a YAML recipe so users can swap normalization strategies without touching framework code. Default recipe lives at: composer_replication/recipes/replaysim/default.yaml The data-juicer dependency is optional (pulled by the `[replaysim]` extra). This file imports it lazily inside method bodies so that the package imports cleanly without it. Source-of-truth shape (from `composer_replication.teacher_replay`): DPOPair = TypedDict("DPOPair", { "state_id": str, "state_messages": list[dict], # conversation up to this step "chosen": str, # teacher-consensus action "rejected": str, # student action "n_teachers_agreeing": int, }) The normalizer does NOT require chosen_teacher / rejected_teacher fields — those don't exist in the real DPOPair shape. """ from __future__ import annotations import asyncio import json import os import tempfile from dataclasses import dataclass from pathlib import Path from typing import Any, Iterable, cast from composer_replication.teacher_replay import ( DPOPair, TeacherCallResult, extract_dpo_pairs, replay_trace, ) @dataclass class NormalizedDPOPair: """A DPOPair that has passed through normalization. Same data as DPOPair but reshaped into chat-messages format (matching data-juicer's native multi-turn op support) plus a metadata dict tracking which ops fired. """ state_id: str """Identifier for the trace state (turn) this pair came from.""" state_messages: list[dict[str, Any]] """The conversation context up to (and including) this step's user prompt.""" chosen_messages: list[dict[str, Any]] """The chosen completion as a chat-messages list (one assistant turn).""" rejected_messages: list[dict[str, Any]] """The rejected completion as a chat-messages list (one assistant turn).""" n_teachers_agreeing: int """How many teachers agreed on the chosen action (preserved from DPOPair).""" metadata: dict[str, Any] """Op-graph provenance: which ops fired, what they changed.""" def _dpo_pair_to_dj_record(pair: DPOPair | dict[str, Any]) -> dict[str, Any]: """Convert a DPOPair (or dict-shaped equivalent) into a data-juicer record. The record carries TWO shapes for chosen/rejected so that data-juicer ops that expect string-typed text fields (e.g. ``text_length_filter``, ``words_num_filter``, ``special_characters_filter``, ``document_deduplicator``) work alongside chat-aware ops: - ``chosen`` / ``rejected``: flat strings (drives the standard text ops that read string fields via ``text_keys``). - ``chosen_messages`` / ``rejected_messages``: chat-messages list (one assistant turn each), preserving the multi-turn-aware shape. The ``messages`` field carries the conversation context (matches data-juicer's ``messages`` convention for chat-aware filters). """ p = cast(dict[str, Any], pair) chosen_str = p.get("chosen", "") or "" rejected_str = p.get("rejected", "") or "" return { "state_id": p.get("state_id", ""), "messages": p.get("state_messages", []), # Flat-string shape for length/word/special-char/dedup filters # that expect text_keys to point at strings. "chosen": chosen_str, "rejected": rejected_str, # Chat-messages shape for chat-aware ops and the NormalizedDPOPair # round-trip. "chosen_messages": [{"role": "assistant", "content": chosen_str}], "rejected_messages": [{"role": "assistant", "content": rejected_str}], "n_teachers_agreeing": p.get("n_teachers_agreeing", 0), } def _dj_record_to_normalized(rec: dict[str, Any]) -> NormalizedDPOPair: """Inverse — convert a data-juicer record back to NormalizedDPOPair. Tolerates records that only carry one of the two shapes: - If ``chosen_messages``/``rejected_messages`` are present, use them directly. - Otherwise wrap the flat-string ``chosen``/``rejected`` fields into a single-assistant-turn messages list. This handles the case where a data-juicer op rewrites the string field but doesn't touch the messages field. """ def _to_messages(val: Any, fallback_str: Any) -> list[dict[str, Any]]: if isinstance(val, list) and val: return val # already chat-messages shape if isinstance(fallback_str, str) and fallback_str: return [{"role": "assistant", "content": fallback_str}] if isinstance(fallback_str, list): # Edge case: someone put the messages list in the flat field. return fallback_str return [] chosen_messages = _to_messages( rec.get("chosen_messages"), rec.get("chosen", "") ) rejected_messages = _to_messages( rec.get("rejected_messages"), rec.get("rejected", "") ) return NormalizedDPOPair( state_id=rec.get("state_id", ""), state_messages=rec.get("messages", []), chosen_messages=chosen_messages, rejected_messages=rejected_messages, n_teachers_agreeing=rec.get("n_teachers_agreeing", 0), metadata=rec.get("__dj_meta__", {}), ) class DJNormalizer: """data-juicer-backed normalizer for DPO pairs. Args: recipe_path: path to a data-juicer YAML recipe. If None, uses the framework's default recipe (length filter + chat-template validation + per-conversation dedup). skip_dj: if True, the normalizer becomes a passthrough — useful for test environments without data-juicer installed. Records are still converted to NormalizedDPOPair shape but no ops run. """ DEFAULT_RECIPE = ( Path(__file__).parent.parent / "recipes" / "replaysim" / "default.yaml" ) def __init__( self, recipe_path: str | os.PathLike[str] | None = None, *, skip_dj: bool = False, ) -> None: self.recipe_path = ( Path(recipe_path) if recipe_path is not None else self.DEFAULT_RECIPE ) self.skip_dj = skip_dj if not skip_dj: try: import data_juicer # type: ignore[import-not-found] # noqa: F401 except ImportError as e: raise RuntimeError( "DJNormalizer requires data-juicer. Install with " "`pip install -e .[replaysim]` or pass skip_dj=True " "for a passthrough. Got: " + repr(e) ) if not self.skip_dj and not self.recipe_path.exists(): raise FileNotFoundError( f"Recipe not found: {self.recipe_path}. Either pass an " f"explicit recipe_path or add the default recipe at this " f"location." ) def normalize( self, pairs: Iterable[DPOPair | dict[str, Any]], ) -> list[NormalizedDPOPair]: """Run the full normalization op-graph on a batch of DPO pairs. Args: pairs: iterable of DPOPair (output of extract_dpo_pairs) or dict-shaped equivalents. Returns: list of NormalizedDPOPair, possibly shorter than input (filter ops can drop records). """ records = [_dpo_pair_to_dj_record(p) for p in pairs] if self.skip_dj: for rec in records: rec["__dj_meta__"] = {"skipped": True} return [_dj_record_to_normalized(r) for r in records] # Real path: write to temp JSONL, hand to data-juicer's Executor, # read back. data-juicer's CLI contract is file-in / file-out. from data_juicer.config import init_configs # type: ignore[import-not-found] from data_juicer.core import DefaultExecutor # type: ignore[import-not-found] with tempfile.TemporaryDirectory() as td: input_path = Path(td) / "input.jsonl" output_path = Path(td) / "output.jsonl" with input_path.open("w") as f: for rec in records: f.write(json.dumps(rec) + "\n") cfg = init_configs( args=[ "--config", str(self.recipe_path), "--dataset_path", str(input_path), "--export_path", str(output_path), ], ) executor = DefaultExecutor(cfg) executor.run() output_records: list[dict[str, Any]] = [] with output_path.open() as f: for line in f: line = line.strip() if not line: continue output_records.append(json.loads(line)) return [_dj_record_to_normalized(r) for r in output_records] # --------------------------------------------------------------------- # Convenience: replay + extract pairs + normalize, end to end. # --------------------------------------------------------------------- async def replay_and_normalize_trace( *, states: Any, teachers: Any = None, agreement_threshold: int = 2, max_total_usd: float = 5.0, normalizer: DJNormalizer | None = None, **replay_kwargs: Any, ) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]: """Async convenience: replay → extract pairs → normalize, in one call. The underlying `replay_trace` is async; this wrapper preserves that so callers can `await` it from an async context. For sync callers use `replay_and_normalize_trace_sync`. Args: states: sequence of TraceState (the frozen agentic trace) teachers: sequence of TeacherSpec (default: framework defaults) agreement_threshold: passed to `extract_dpo_pairs` max_total_usd: passed to `replay_trace` normalizer: defaults to `DJNormalizer()`. Pass `DJNormalizer(skip_dj=True)` to bypass data-juicer. **replay_kwargs: extra kwargs forwarded to `replay_trace`. Returns: Tuple of (raw teacher_actions, normalized DPO pairs). """ if normalizer is None: normalizer = DJNormalizer() if teachers is None: teacher_actions = await replay_trace( states=states, max_total_usd=max_total_usd, **replay_kwargs, ) else: teacher_actions = await replay_trace( states=states, teachers=teachers, max_total_usd=max_total_usd, **replay_kwargs, ) # extract_dpo_pairs reads student_action from each state's # `student_action` field, so we don't need to pass it separately. raw_pairs = extract_dpo_pairs( states=states, teacher_actions=teacher_actions, agreement_threshold=agreement_threshold, ) normalized = normalizer.normalize(raw_pairs) return teacher_actions, normalized def replay_and_normalize_trace_sync( *args: Any, **kwargs: Any, ) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]: """Sync wrapper for the async `replay_and_normalize_trace`. Convenient for scripts and tests. """ return asyncio.run(replay_and_normalize_trace(*args, **kwargs)) __all__ = [ "DJNormalizer", "NormalizedDPOPair", "replay_and_normalize_trace", "replay_and_normalize_trace_sync", ]