File size: 6,177 Bytes
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""real_batch.py — build a real, tokenized 3-channel batch from a HF tokenizer.

Used by Spike 006's smoke to generate inputs for `compose_loss` from a real
chat-template-formatted conversation, NOT random ints.
"""
from __future__ import annotations

from typing import Any

import torch


_SYSTEM_PROMPT = "You are a careful coding assistant."


def _encode(tokenizer: Any, messages: list[dict[str, str]]) -> torch.Tensor:
    """Render *messages* with the chat template and tokenize the result.

    Returns the ``input_ids`` tensor from ``tokenizer(..., return_tensors="pt")``.
    The chat template is assumed to emit its own special tokens (BOS, turn
    markers), so the tokenizer is told not to add more.
    """
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    return tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"]


def _response_start(
    tokenizer: Any,
    messages: list[dict[str, str]],
    ids: torch.Tensor,
    fallback_frac: float,
) -> int:
    """Return the column of *ids* where the final message's tokens begin.

    Everything except the last message is rendered with
    ``add_generation_prompt=True`` and tokenized exactly like :func:`_encode`.
    If that prompt is a strict token prefix of ``ids[0]``, its length is the
    exact prompt/response boundary. Otherwise (e.g. tokens merge across the
    seam, or the template renders history differently from a live generation
    prompt), fall back to the fixed-fraction heuristic ``int(T * fallback_frac)``.
    """
    total = ids.shape[1]
    prompt_text = tokenizer.apply_chat_template(
        messages[:-1], tokenize=False, add_generation_prompt=True
    )
    prompt_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)["input_ids"]
    n_prompt = prompt_ids.shape[1]
    # Require a non-empty prompt AND a non-empty response; otherwise the
    # "exact" boundary would be meaningless.
    if 0 < n_prompt < total and torch.equal(ids[0, :n_prompt], prompt_ids[0]):
        return n_prompt
    return int(total * fallback_frac)


def _span_mask(like: torch.Tensor, start: int, end: int | None = None) -> torch.Tensor:
    """Return ``zeros_like(like)`` with 1s in columns ``[start, end)`` (to the end if *end* is None)."""
    mask = torch.zeros_like(like)
    mask[:, start:end] = 1
    return mask


def _pad_right(ids: torch.Tensor, length: int, pad_id: int | None) -> torch.Tensor:
    """Right-pad *ids* with *pad_id* to *length* columns (truncating if already longer).

    Raises:
        ValueError: padding is required but *pad_id* is ``None``.
    """
    cur = ids.shape[1]
    if cur >= length:
        return ids[:, :length]
    if pad_id is None:
        raise ValueError(
            "tokenizer has neither pad_token_id nor eos_token_id; cannot pad the DPO pair"
        )
    filler = torch.full((ids.shape[0], length - cur), pad_id, dtype=ids.dtype, device=ids.device)
    return torch.cat([ids, filler], dim=1)


def build_batch(
    tokenizer: Any,
    *,
    device: torch.device | str = "cpu",
    seed: int = 42,
) -> dict[str, torch.Tensor]:
    """Construct a full 3-channel input batch from a real tokenizer.

    Returns a dict with all keys `compose_loss` may consume:
        input_ids, response_mask
        ctx_teacher_input_ids, sdpo_loss_mask
        dpo_chosen_input_ids, dpo_chosen_response_mask
        dpo_rejected_input_ids, dpo_rejected_response_mask
        dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs

    Each mask is 1 exactly on the tokens of its conversation's final
    (assistant) message. The boundary is found by tokenizing the prompt with
    the chat template's generation prompt. If that prompt is not a token
    prefix of the full conversation, a fixed-fraction heuristic is used
    instead: the last 30% of tokens for student/teacher, the last 40% for DPO.
    DPO masks never cover right-padding.

    The DPO ref logprobs are dummy tensors (not from a real reference policy
    forward); the smoke is verifying the loss composition wires together,
    not the reference-policy precompute pipeline.

    Args:
        tokenizer: HF-style tokenizer exposing ``apply_chat_template``,
            ``__call__``, ``pad_token_id`` and ``eos_token_id``.
        device: Device every returned tensor is moved to.
        seed: Passed to ``torch.manual_seed``, which reseeds torch's global
            RNG. Nothing in this function draws random numbers itself.

    Raises:
        ValueError: the DPO pair needs padding but the tokenizer defines
            neither ``pad_token_id`` nor ``eos_token_id``.
    """
    torch.manual_seed(seed)

    # ------------------------------------------------------------------
    # Conversation 1: student rollout
    # ------------------------------------------------------------------
    student_msgs = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": "Write a Python function to compute the factorial of n."},
        {"role": "assistant", "content": "def factorial(n):\n    if n <= 1: return 1\n    return n * factorial(n - 1)"},
    ]
    input_ids = _encode(tokenizer, student_msgs)
    response_mask = _span_mask(input_ids, _response_start(tokenizer, student_msgs, input_ids, 0.7))

    # ------------------------------------------------------------------
    # Conversation 2: hint-conditioned teacher context (SDPO)
    # ------------------------------------------------------------------
    # NOTE(review): the hint is a second consecutive "user" turn; some chat
    # templates reject non-alternating roles -- confirm for the target model.
    teacher_msgs = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": "Write a Python function to compute the factorial of n."},
        {"role": "user", "content": "[HINT] Recursion overflows for n>1000. Use an iterative loop."},
        {"role": "assistant", "content": "def factorial(n):\n    result = 1\n    for i in range(2, n + 1):\n        result *= i\n    return result"},
    ]
    ctx_teacher_input_ids = _encode(tokenizer, teacher_msgs)
    # SDPO loss mask: 1 on the post-hint assistant tokens (the "error site").
    sdpo_loss_mask = _span_mask(
        ctx_teacher_input_ids,
        _response_start(tokenizer, teacher_msgs, ctx_teacher_input_ids, 0.7),
    )

    # ------------------------------------------------------------------
    # Conversation 3 + 4: DPO chosen / rejected pairs (shared prompt)
    # ------------------------------------------------------------------
    dpo_prompt = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": "What's the time complexity of binary search?"},
    ]
    dpo_chosen_msgs = [
        *dpo_prompt,
        {"role": "assistant", "content": "Binary search is O(log n) because each comparison halves the search space."},
    ]
    dpo_rejected_msgs = [
        *dpo_prompt,
        {"role": "assistant", "content": "It's O(n) I think, you have to look at every element."},
    ]
    chosen_ids = _encode(tokenizer, dpo_chosen_msgs)
    rejected_ids = _encode(tokenizer, dpo_rejected_msgs)
    chosen_start = _response_start(tokenizer, dpo_chosen_msgs, chosen_ids, 0.6)
    rejected_start = _response_start(tokenizer, dpo_rejected_msgs, rejected_ids, 0.6)

    # Pad both sequences to the same length so we can stack them.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    width = max(chosen_ids.shape[1], rejected_ids.shape[1])
    dpo_chosen_input_ids = _pad_right(chosen_ids, width, pad_id)
    dpo_rejected_input_ids = _pad_right(rejected_ids, width, pad_id)

    # Each mask ends at that sequence's own unpadded length, so pad tokens
    # never contribute to the response log-prob.
    chosen_resp_mask = _span_mask(dpo_chosen_input_ids, chosen_start, chosen_ids.shape[1])
    rejected_resp_mask = _span_mask(dpo_rejected_input_ids, rejected_start, rejected_ids.shape[1])

    # Dummy reference-policy logprobs (in production: precomputed by data collator)
    dpo_chosen_ref_logprobs = torch.tensor([-30.0])
    dpo_rejected_ref_logprobs = torch.tensor([-35.0])

    batch = {
        "input_ids": input_ids,
        "response_mask": response_mask,
        "ctx_teacher_input_ids": ctx_teacher_input_ids,
        "sdpo_loss_mask": sdpo_loss_mask,
        "dpo_chosen_input_ids": dpo_chosen_input_ids,
        "dpo_chosen_response_mask": chosen_resp_mask,
        "dpo_rejected_input_ids": dpo_rejected_input_ids,
        "dpo_rejected_response_mask": rejected_resp_mask,
        "dpo_chosen_ref_logprobs": dpo_chosen_ref_logprobs,
        "dpo_rejected_ref_logprobs": dpo_rejected_ref_logprobs,
    }
    # Everything is built on CPU (so the prefix comparison above is
    # device-agnostic) and moved to the requested device in one place.
    return {key: value.to(device) for key, value in batch.items()}


# Names exported by a star-import of this module.
__all__ = [
    "build_batch",
]