"""real_batch.py — build a real, tokenized 3-channel batch from a HF tokenizer. Used by Spike 006's smoke to generate inputs for `compose_loss` from a real chat-template-formatted conversation, NOT random ints. """ from __future__ import annotations from typing import Any import torch def build_batch( tokenizer: Any, *, device: torch.device | str = "cpu", seed: int = 42, ) -> dict[str, torch.Tensor]: """Construct a full 3-channel input batch from a real tokenizer. Returns a dict with all keys `compose_loss` may consume: input_ids, response_mask ctx_teacher_input_ids, sdpo_loss_mask dpo_chosen_input_ids, dpo_chosen_response_mask dpo_rejected_input_ids, dpo_rejected_response_mask dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs The DPO ref logprobs are dummy tensors (not from a real reference policy forward); the smoke is verifying the loss composition wires together, not the reference-policy precompute pipeline. """ torch.manual_seed(seed) # ------------------------------------------------------------------ # Conversation 1: student rollout # ------------------------------------------------------------------ student_msgs = [ {"role": "system", "content": "You are a careful coding assistant."}, {"role": "user", "content": "Write a Python function to compute the factorial of n."}, {"role": "assistant", "content": "def factorial(n):\n if n <= 1: return 1\n return n * factorial(n - 1)"}, ] student_text = tokenizer.apply_chat_template(student_msgs, tokenize=False, add_generation_prompt=False) student_enc = tokenizer(student_text, return_tensors="pt", add_special_tokens=False) input_ids = student_enc["input_ids"].to(device) # response_mask: rough heuristic — last 30% of tokens are "the response" # (good enough for a smoke; production uses chat-template offsets) T = input_ids.shape[1] response_mask = torch.zeros_like(input_ids) response_mask[:, int(T * 0.7):] = 1 # ------------------------------------------------------------------ # Conversation 2: hint-conditioned teacher context (SDPO) # ------------------------------------------------------------------ teacher_msgs = [ {"role": "system", "content": "You are a careful coding assistant."}, {"role": "user", "content": "Write a Python function to compute the factorial of n."}, {"role": "user", "content": "[HINT] Recursion overflows for n>1000. Use an iterative loop."}, {"role": "assistant", "content": "def factorial(n):\n result = 1\n for i in range(2, n + 1):\n result *= i\n return result"}, ] teacher_text = tokenizer.apply_chat_template(teacher_msgs, tokenize=False, add_generation_prompt=False) teacher_enc = tokenizer(teacher_text, return_tensors="pt", add_special_tokens=False) ctx_teacher_input_ids = teacher_enc["input_ids"].to(device) # SDPO loss mask: 1 on the post-hint assistant tokens (the "error site") T_t = ctx_teacher_input_ids.shape[1] sdpo_loss_mask = torch.zeros_like(ctx_teacher_input_ids) sdpo_loss_mask[:, int(T_t * 0.7):] = 1 # ------------------------------------------------------------------ # Conversation 3 + 4: DPO chosen / rejected pairs # ------------------------------------------------------------------ dpo_chosen_msgs = [ {"role": "system", "content": "You are a careful coding assistant."}, {"role": "user", "content": "What's the time complexity of binary search?"}, {"role": "assistant", "content": "Binary search is O(log n) because each comparison halves the search space."}, ] dpo_rejected_msgs = [ {"role": "system", "content": "You are a careful coding assistant."}, {"role": "user", "content": "What's the time complexity of binary search?"}, {"role": "assistant", "content": "It's O(n) I think, you have to look at every element."}, ] chosen_text = tokenizer.apply_chat_template(dpo_chosen_msgs, tokenize=False, add_generation_prompt=False) rejected_text = tokenizer.apply_chat_template(dpo_rejected_msgs, tokenize=False, add_generation_prompt=False) # Pad both sequences to the same length so we can stack them chosen_enc = tokenizer(chosen_text, return_tensors="pt", add_special_tokens=False, padding=False) rejected_enc = tokenizer(rejected_text, return_tensors="pt", add_special_tokens=False, padding=False) pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id chosen_ids = chosen_enc["input_ids"] rejected_ids = rejected_enc["input_ids"] L = max(chosen_ids.shape[1], rejected_ids.shape[1]) def _pad(ids: torch.Tensor, length: int) -> torch.Tensor: cur = ids.shape[1] if cur >= length: return ids[:, :length] return torch.cat([ids, torch.full((1, length - cur), pad_id, dtype=ids.dtype)], dim=1) dpo_chosen_input_ids = _pad(chosen_ids, L).to(device) dpo_rejected_input_ids = _pad(rejected_ids, L).to(device) chosen_resp_mask = torch.zeros_like(dpo_chosen_input_ids) chosen_resp_mask[:, int(L * 0.6):chosen_ids.shape[1]] = 1 rejected_resp_mask = torch.zeros_like(dpo_rejected_input_ids) rejected_resp_mask[:, int(L * 0.6):rejected_ids.shape[1]] = 1 # Dummy reference-policy logprobs (in production: precomputed by data collator) dpo_chosen_ref_logprobs = torch.tensor([-30.0], device=device) dpo_rejected_ref_logprobs = torch.tensor([-35.0], device=device) return { "input_ids": input_ids, "response_mask": response_mask, "ctx_teacher_input_ids": ctx_teacher_input_ids, "sdpo_loss_mask": sdpo_loss_mask, "dpo_chosen_input_ids": dpo_chosen_input_ids, "dpo_chosen_response_mask": chosen_resp_mask, "dpo_rejected_input_ids": dpo_rejected_input_ids, "dpo_rejected_response_mask": rejected_resp_mask, "dpo_chosen_ref_logprobs": dpo_chosen_ref_logprobs, "dpo_rejected_ref_logprobs": dpo_rejected_ref_logprobs, } __all__ = ["build_batch"]