Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 6,177 Bytes
ac05fbf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 | """real_batch.py — build a real, tokenized 3-channel batch from a HF tokenizer.
Used by Spike 006's smoke to generate inputs for `compose_loss` from a real
chat-template-formatted conversation, NOT random ints.
"""
from __future__ import annotations
from typing import Any
import torch
def _encode_chat(tokenizer: Any, messages: list[dict[str, str]], **encode_kwargs: Any) -> torch.Tensor:
    """Render *messages* through the chat template and return the tokenized ``input_ids``."""
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    # The template already inserted its special tokens, so do not add them again.
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False, **encode_kwargs)
    return encoded["input_ids"]


def _span_mask(ids: torch.Tensor, start: int, stop: int | None = None) -> torch.Tensor:
    """Return a mask shaped like *ids* with ones on columns ``start:stop`` and zeros elsewhere."""
    mask = torch.zeros_like(ids)
    mask[:, start:stop] = 1
    return mask


def _right_pad(ids: torch.Tensor, length: int, pad_id: int | None) -> torch.Tensor:
    """Truncate or right-pad *ids* along dim 1 so it has exactly *length* columns."""
    cur = ids.shape[1]
    if cur >= length:
        return ids[:, :length]
    if pad_id is None:
        # Only reached when padding is actually required; fail with a clear message
        # instead of letting torch.full choke on a None fill value.
        raise ValueError(
            "tokenizer has neither pad_token_id nor eos_token_id; cannot pad DPO pair"
        )
    filler = torch.full((ids.shape[0], length - cur), pad_id, dtype=ids.dtype, device=ids.device)
    return torch.cat([ids, filler], dim=1)


def build_batch(
    tokenizer: Any,
    *,
    device: torch.device | str = "cpu",
    seed: int = 42,
) -> dict[str, torch.Tensor]:
    """Construct a full 3-channel input batch from a real tokenizer.

    Four fixed conversations are rendered with ``tokenizer.apply_chat_template``
    and tokenized: a student rollout, a hint-conditioned teacher context (SDPO),
    and a DPO chosen / rejected pair.

    Args:
        tokenizer: Object exposing ``apply_chat_template``, ``__call__``
            (returning a mapping with ``"input_ids"``), ``pad_token_id`` and
            ``eos_token_id``.
        device: Device every returned tensor is moved to / created on.
        seed: Passed to ``torch.manual_seed``. Nothing in this function draws
            random numbers; the call only fixes the global torch RNG state.

    Returns:
        A dict with all keys `compose_loss` may consume:
            input_ids, response_mask
            ctx_teacher_input_ids, sdpo_loss_mask
            dpo_chosen_input_ids, dpo_chosen_response_mask
            dpo_rejected_input_ids, dpo_rejected_response_mask
            dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs

        The DPO ref logprobs are dummy tensors (not from a real reference policy
        forward); the smoke is verifying the loss composition wires together,
        not the reference-policy precompute pipeline.

    Raises:
        ValueError: If the DPO pair needs padding but the tokenizer defines
            neither ``pad_token_id`` nor ``eos_token_id``.
    """
    torch.manual_seed(seed)

    system_turn = {"role": "system", "content": "You are a careful coding assistant."}

    # ------------------------------------------------------------------
    # Conversation 1: student rollout
    # ------------------------------------------------------------------
    student_msgs = [
        system_turn,
        {"role": "user", "content": "Write a Python function to compute the factorial of n."},
        {"role": "assistant", "content": "def factorial(n):\n if n <= 1: return 1\n return n * factorial(n - 1)"},
    ]
    input_ids = _encode_chat(tokenizer, student_msgs).to(device)
    # response_mask: rough heuristic — last 30% of tokens are "the response"
    # (good enough for a smoke; production uses chat-template offsets)
    response_mask = _span_mask(input_ids, int(input_ids.shape[1] * 0.7))

    # ------------------------------------------------------------------
    # Conversation 2: hint-conditioned teacher context (SDPO)
    # ------------------------------------------------------------------
    # NOTE(review): two consecutive "user" turns — some chat templates may
    # reject non-alternating roles; confirm against the target tokenizer.
    teacher_msgs = [
        system_turn,
        {"role": "user", "content": "Write a Python function to compute the factorial of n."},
        {"role": "user", "content": "[HINT] Recursion overflows for n>1000. Use an iterative loop."},
        {"role": "assistant", "content": "def factorial(n):\n result = 1\n for i in range(2, n + 1):\n result *= i\n return result"},
    ]
    ctx_teacher_input_ids = _encode_chat(tokenizer, teacher_msgs).to(device)
    # SDPO loss mask: 1 on the post-hint assistant tokens (the "error site"),
    # approximated here as the last 30% of the teacher sequence.
    sdpo_loss_mask = _span_mask(ctx_teacher_input_ids, int(ctx_teacher_input_ids.shape[1] * 0.7))

    # ------------------------------------------------------------------
    # Conversation 3 + 4: DPO chosen / rejected pairs
    # ------------------------------------------------------------------
    dpo_question = {"role": "user", "content": "What's the time complexity of binary search?"}
    dpo_chosen_msgs = [
        system_turn,
        dpo_question,
        {"role": "assistant", "content": "Binary search is O(log n) because each comparison halves the search space."},
    ]
    dpo_rejected_msgs = [
        system_turn,
        dpo_question,
        {"role": "assistant", "content": "It's O(n) I think, you have to look at every element."},
    ]
    chosen_ids = _encode_chat(tokenizer, dpo_chosen_msgs, padding=False)
    rejected_ids = _encode_chat(tokenizer, dpo_rejected_msgs, padding=False)

    # Pad both sequences to the same length so downstream code can stack them.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    chosen_len = chosen_ids.shape[1]
    rejected_len = rejected_ids.shape[1]
    pair_len = max(chosen_len, rejected_len)
    dpo_chosen_input_ids = _right_pad(chosen_ids, pair_len, pad_id).to(device)
    dpo_rejected_input_ids = _right_pad(rejected_ids, pair_len, pad_id).to(device)

    # Response heuristic: from 60% of the padded length up to the last real
    # (non-pad) token. The start is clamped to the final real token so the
    # shorter sequence never ends up with an all-zero mask.
    resp_start = int(pair_len * 0.6)
    chosen_resp_mask = _span_mask(
        dpo_chosen_input_ids, max(0, min(resp_start, chosen_len - 1)), chosen_len
    )
    rejected_resp_mask = _span_mask(
        dpo_rejected_input_ids, max(0, min(resp_start, rejected_len - 1)), rejected_len
    )

    # Dummy reference-policy logprobs (in production: precomputed by data collator)
    dpo_chosen_ref_logprobs = torch.tensor([-30.0], device=device)
    dpo_rejected_ref_logprobs = torch.tensor([-35.0], device=device)

    return {
        "input_ids": input_ids,
        "response_mask": response_mask,
        "ctx_teacher_input_ids": ctx_teacher_input_ids,
        "sdpo_loss_mask": sdpo_loss_mask,
        "dpo_chosen_input_ids": dpo_chosen_input_ids,
        "dpo_chosen_response_mask": chosen_resp_mask,
        "dpo_rejected_input_ids": dpo_rejected_input_ids,
        "dpo_rejected_response_mask": rejected_resp_mask,
        "dpo_chosen_ref_logprobs": dpo_chosen_ref_logprobs,
        "dpo_rejected_ref_logprobs": dpo_rejected_ref_logprobs,
    }
# Explicit public API: `build_batch` is the only name exported by a star-import.
__all__ = ["build_batch"]
|