File size: 11,683 Bytes
b266c31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
b266c31
 
d9dd3a5
 
b266c31
 
 
d9dd3a5
 
 
 
 
 
 
 
b266c31
 
 
 
 
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b266c31
 
 
d9dd3a5
 
b266c31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
"""DJNormalizer — data-juicer adapter for replaysim DPO output.

Wraps the framework's `extract_dpo_pairs` output in a data-juicer op-graph.
The op-graph runs entirely CPU-side and applies length filtering, chat-
template validation, and per-conversation deduplication. Ops are loaded
from a YAML recipe so users can swap normalization strategies without
touching framework code.

Default recipe lives at:
    composer_replication/recipes/replaysim/default.yaml

The data-juicer dependency is optional (pulled by the `[replaysim]` extra).
This file imports it lazily inside method bodies so that the package
imports cleanly without it.

Source-of-truth shape (from `composer_replication.teacher_replay`):

    DPOPair = TypedDict("DPOPair", {
        "state_id":           str,
        "state_messages":     list[dict],   # conversation up to this step
        "chosen":             str,          # teacher-consensus action
        "rejected":           str,          # student action
        "n_teachers_agreeing": int,
    })

The normalizer does NOT require chosen_teacher / rejected_teacher fields —
those don't exist in the real DPOPair shape.
"""
from __future__ import annotations

import asyncio
import json
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, cast

from composer_replication.teacher_replay import (
    DPOPair,
    TeacherCallResult,
    extract_dpo_pairs,
    replay_trace,
)


@dataclass
class NormalizedDPOPair:
    """One DPO preference pair after the normalization op-graph.

    Carries the same information as a ``DPOPair``, but the chosen and
    rejected completions are expressed as chat-messages lists (one
    assistant turn each) rather than flat strings. It also carries a
    ``metadata`` dict describing what the op-graph did to the record.
    """

    # Identifier of the trace state (turn) the pair was extracted from.
    state_id: str
    # Conversation context up to and including this step's user prompt.
    state_messages: list[dict[str, Any]]
    # Preferred completion as a chat-messages list (one assistant turn).
    chosen_messages: list[dict[str, Any]]
    # Dispreferred completion as a chat-messages list (one assistant turn).
    rejected_messages: list[dict[str, Any]]
    # How many teachers agreed on the chosen action (carried over from DPOPair).
    n_teachers_agreeing: int
    # Op-graph provenance: which ops fired, what they changed.
    metadata: dict[str, Any]


def _dpo_pair_to_dj_record(pair: DPOPair | dict[str, Any]) -> dict[str, Any]:
    """Convert a DPOPair (or dict-shaped equivalent) into a data-juicer record.

    The record carries TWO shapes for chosen/rejected so that data-juicer ops
    that expect string-typed text fields (e.g. ``text_length_filter``,
    ``words_num_filter``, ``special_characters_filter``,
    ``document_deduplicator``) work alongside chat-aware ops:

    - ``chosen`` / ``rejected``: flat strings (drives the standard text ops
      that read string fields via ``text_keys``).
    - ``chosen_messages`` / ``rejected_messages``: chat-messages list
      (one assistant turn each), preserving the multi-turn-aware shape.

    The ``messages`` field carries the conversation context (matches
    data-juicer's ``messages`` convention for chat-aware filters).
    """
    p = cast(dict[str, Any], pair)
    chosen_str = p.get("chosen", "") or ""
    rejected_str = p.get("rejected", "") or ""
    return {
        "state_id": p.get("state_id", ""),
        "messages": p.get("state_messages", []),
        # Flat-string shape for length/word/special-char/dedup filters
        # that expect text_keys to point at strings.
        "chosen": chosen_str,
        "rejected": rejected_str,
        # Chat-messages shape for chat-aware ops and the NormalizedDPOPair
        # round-trip.
        "chosen_messages": [{"role": "assistant", "content": chosen_str}],
        "rejected_messages": [{"role": "assistant", "content": rejected_str}],
        "n_teachers_agreeing": p.get("n_teachers_agreeing", 0),
    }


def _dj_record_to_normalized(rec: dict[str, Any]) -> NormalizedDPOPair:
    """Inverse — convert a data-juicer record back to NormalizedDPOPair.

    Each side (chosen / rejected) is resolved from the record's two shapes:

    - If the flat string (``chosen`` / ``rejected``) is non-empty and
      disagrees with the content of a single-assistant-turn
      ``chosen_messages`` / ``rejected_messages`` list, the flat string
      wins. ``_dpo_pair_to_dj_record`` writes both shapes with identical
      content, and standard text ops operate on the string fields named in
      ``text_keys``. A mismatch therefore means an op rewrote the string
      and the messages copy is stale.
    - Otherwise, a non-empty messages list is used as-is.
    - Otherwise a non-empty flat string is wrapped into a single assistant
      turn (records that only carry the flat shape).
    - A list found in the flat field is passed through.
    - Anything else (missing, empty, other types) yields ``[]``.

    ``metadata`` is a copy of the record's ``__dj_meta__`` dict (written by
    the skip_dj passthrough). When the exported record carries
    data-juicer's own ``__dj__stats__`` / ``__dj__meta__`` columns, those
    are added to it as well.
    """
    def _is_single_assistant_turn(val: Any) -> bool:
        return (
            isinstance(val, list)
            and len(val) == 1
            and isinstance(val[0], dict)
            and val[0].get("role") == "assistant"
        )

    def _to_messages(val: Any, flat: Any) -> list[dict[str, Any]]:
        if (
            isinstance(flat, str)
            and flat
            and _is_single_assistant_turn(val)
            and val[0].get("content") != flat
        ):
            # The flat string was rewritten by an op while the messages copy
            # was left untouched: take the rewritten text, keep other keys.
            return [{**val[0], "content": flat}]
        if isinstance(val, list) and val:
            return val  # already chat-messages shape
        if isinstance(flat, str) and flat:
            return [{"role": "assistant", "content": flat}]
        if isinstance(flat, list):
            # Edge case: someone put the messages list in the flat field.
            return flat
        return []

    raw_meta = rec.get("__dj_meta__")
    metadata: dict[str, Any] = dict(raw_meta) if isinstance(raw_meta, dict) else {}
    # data-juicer's reserved per-sample stats/meta columns. They are
    # presumably only present when the recipe keeps stats in the exported
    # dataset — TODO confirm against the pinned data-juicer version.
    for dj_field in ("__dj__stats__", "__dj__meta__"):
        if dj_field in rec:
            metadata.setdefault(dj_field, rec[dj_field])

    return NormalizedDPOPair(
        state_id=rec.get("state_id", ""),
        state_messages=rec.get("messages", []),
        # Pass the flat field without a default: a missing key (None) and
        # an empty string both resolve to [] in _to_messages, and neither
        # overrides an existing messages list.
        chosen_messages=_to_messages(rec.get("chosen_messages"), rec.get("chosen")),
        rejected_messages=_to_messages(
            rec.get("rejected_messages"), rec.get("rejected")
        ),
        n_teachers_agreeing=rec.get("n_teachers_agreeing", 0),
        metadata=metadata,
    )


class DJNormalizer:
    """data-juicer-backed normalizer for DPO pairs.

    Args:
        recipe_path: path to a data-juicer YAML recipe. If None, uses
            ``DEFAULT_RECIPE``, the framework's default recipe (length filter
            + chat-template validation + per-conversation dedup).
        skip_dj: if True, the normalizer becomes a passthrough — useful
            for test environments without data-juicer installed. Records
            are still converted to NormalizedDPOPair shape but no ops run.

    Raises:
        RuntimeError: from ``__init__`` when ``skip_dj`` is False and
            data-juicer cannot be imported. The ImportError is chained as
            ``__cause__``.
        FileNotFoundError: from ``__init__`` when ``skip_dj`` is False and
            the recipe file does not exist.
    """

    # recipes/replaysim/default.yaml under the parent of this module's
    # directory.
    DEFAULT_RECIPE = (
        Path(__file__).parent.parent / "recipes" / "replaysim" / "default.yaml"
    )

    def __init__(
        self,
        recipe_path: str | os.PathLike[str] | None = None,
        *,
        skip_dj: bool = False,
    ) -> None:
        self.recipe_path = (
            Path(recipe_path) if recipe_path is not None else self.DEFAULT_RECIPE
        )
        self.skip_dj = skip_dj

        if self.skip_dj:
            # Passthrough mode needs neither data-juicer nor a recipe file.
            return

        # Probe the optional dependency eagerly so a missing install fails at
        # construction time, not on the first normalize() call.
        try:
            import data_juicer  # type: ignore[import-not-found]  # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                "DJNormalizer requires data-juicer. Install with "
                "`pip install -e .[replaysim]` or pass skip_dj=True "
                "for a passthrough. Got: " + repr(e)
            ) from e

        if not self.recipe_path.exists():
            raise FileNotFoundError(
                f"Recipe not found: {self.recipe_path}. Either pass an "
                f"explicit recipe_path or add the default recipe at this "
                f"location."
            )

    def normalize(
        self,
        pairs: Iterable[DPOPair | dict[str, Any]],
    ) -> list[NormalizedDPOPair]:
        """Run the full normalization op-graph on a batch of DPO pairs.

        Args:
            pairs: iterable of DPOPair (output of extract_dpo_pairs) or
                dict-shaped equivalents.

        Returns:
            list of NormalizedDPOPair, possibly shorter than input (filter
            ops can drop records). An empty input returns ``[]`` without
            invoking data-juicer.
        """
        records = [_dpo_pair_to_dj_record(p) for p in pairs]

        # Nothing to normalize: skip the temp-file round trip (and the
        # data-juicer import) rather than handing the executor an empty
        # dataset.
        if not records:
            return []

        if self.skip_dj:
            for rec in records:
                rec["__dj_meta__"] = {"skipped": True}
            return [_dj_record_to_normalized(r) for r in records]

        # Real path: write to temp JSONL, hand to data-juicer's Executor,
        # read back. data-juicer's CLI contract is file-in / file-out.
        from data_juicer.config import init_configs  # type: ignore[import-not-found]
        from data_juicer.core import DefaultExecutor  # type: ignore[import-not-found]

        with tempfile.TemporaryDirectory() as td:
            input_path = Path(td) / "input.jsonl"
            output_path = Path(td) / "output.jsonl"
            # Explicit UTF-8 on both ends. The export may contain unescaped
            # non-ASCII text, which the platform default encoding (e.g.
            # cp1252 on Windows) cannot decode reliably.
            with input_path.open("w", encoding="utf-8") as f:
                f.writelines(json.dumps(rec) + "\n" for rec in records)
            cfg = init_configs(
                args=[
                    "--config", str(self.recipe_path),
                    "--dataset_path", str(input_path),
                    "--export_path", str(output_path),
                ],
            )
            DefaultExecutor(cfg).run()

            # JSONL: one record per non-blank line.
            with output_path.open(encoding="utf-8") as f:
                output_records: list[dict[str, Any]] = [
                    json.loads(line) for line in f if line.strip()
                ]

        return [_dj_record_to_normalized(r) for r in output_records]


# ---------------------------------------------------------------------
# Convenience: replay + extract pairs + normalize, end to end.
# ---------------------------------------------------------------------


async def replay_and_normalize_trace(
    *,
    states: Any,
    teachers: Any = None,
    agreement_threshold: int = 2,
    max_total_usd: float = 5.0,
    normalizer: DJNormalizer | None = None,
    **replay_kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]:
    """Replay a trace, extract DPO pairs and normalize them, all in one await.

    This is a coroutine because ``replay_trace`` is one. Synchronous callers
    should use ``replay_and_normalize_trace_sync`` instead.

    Args:
        states: sequence of TraceState (the frozen agentic trace).
        teachers: sequence of TeacherSpec. When None, it is not forwarded,
            so ``replay_trace`` applies its own default.
        agreement_threshold: forwarded to ``extract_dpo_pairs``.
        max_total_usd: forwarded to ``replay_trace``.
        normalizer: normalizer to apply. Defaults to a fresh
            ``DJNormalizer()``; pass ``DJNormalizer(skip_dj=True)`` to bypass
            data-juicer.
        **replay_kwargs: any further keyword arguments for ``replay_trace``.

    Returns:
        ``(teacher_actions, normalized_pairs)``: the raw teacher results and
        the normalized DPO pairs.
    """
    # Build the default normalizer before replaying, so a missing
    # data-juicer install raises before any teacher calls are made.
    active_normalizer = DJNormalizer() if normalizer is None else normalizer

    forwarded: dict[str, Any] = dict(replay_kwargs)
    if teachers is not None:
        forwarded["teachers"] = teachers

    teacher_actions = await replay_trace(
        states=states,
        max_total_usd=max_total_usd,
        **forwarded,
    )

    # Student actions are not passed here; per the original note,
    # extract_dpo_pairs reads them from each state's `student_action` field.
    dpo_pairs = extract_dpo_pairs(
        states=states,
        teacher_actions=teacher_actions,
        agreement_threshold=agreement_threshold,
    )
    return teacher_actions, active_normalizer.normalize(dpo_pairs)


def replay_and_normalize_trace_sync(
    *args: Any,
    **kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]:
    """Blocking front-end for the async `replay_and_normalize_trace`.

    All arguments are forwarded unchanged. The coroutine runs on a fresh
    event loop via ``asyncio.run``, so this cannot be called from code that
    is already running inside an event loop. Handy for scripts and tests.
    """
    pipeline = replay_and_normalize_trace(*args, **kwargs)
    return asyncio.run(pipeline)


# Public API: the normalizer, its output record type, and the end-to-end
# replay helpers (async and sync).
__all__ = [
    "DJNormalizer", "NormalizedDPOPair",
    "replay_and_normalize_trace", "replay_and_normalize_trace_sync",
]