# Scraped page header (not part of the program): bbkdevops's picture · download · raw · 14.5 kB
"""
Code GRPO Trainer — Policy Optimization with Python Execution Feedback
โมเดลเขียนโค้ด → รัน test cases จริง → reward จากผลลัพธ์จริง
เหมือน DeepSeek-Coder-V2 แต่ปรับสำหรับ TinyMind
Reward structure:
+1.0 ผ่านทุก test case
+0.5 ผ่านบางส่วน (partial credit)
+0.2 format ถูก (มี <think> + <answer> + def)
0.0 โค้ด compile ได้แต่ logic ผิด
-0.3 SyntaxError
-0.5 ไม่มีโค้ดเลย
"""
from __future__ import annotations
import ast
import io
import json
import re
import sys
import time
import traceback
from contextlib import redirect_stdout, redirect_stderr
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from model.config import OmegaConfig, small_config
from model.architecture import OmegaModel
from model.reasoning import extract_thinking
from train.grpo_trainer import (
compute_group_advantages, grpo_policy_loss, grpo_collate,
GRPO_CFG,
)
# System prompt placed in the <system> section of every training prompt. It asks
# for reasoning inside <think>...</think> followed by the final program inside
# <answer>...</answer>.
CODE_SYSTEM = (
"You are an expert Python programmer. "
"Think step-by-step in <think>...</think>, "
"then provide complete correct Python in <answer>...</answer>."
)
# Trainer configuration: starts from every entry of the imported GRPO_CFG and
# overrides/adds the keys below.
CODE_GRPO_CFG = {
**GRPO_CFG,
# JSONL file of problems read by CodeDataset.
"data_path": "data/filtered/code_grpo.jsonl",
# Checkpoint the trainer starts from when the file exists.
"ref_checkpoint": "checkpoints/omega_best.pt",
# Directory checkpoints are written to.
"out_dir": "checkpoints",
# Number of completions sampled per prompt (the GRPO group size).
"n_samples": 4,
# Generation cap; also the truncation length for re-encoded completions.
"max_new_tokens": 512,
# Sampling temperature passed to model.generate.
"temperature": 0.8,
# Total number of training steps.
"max_steps": 4_000,
# A checkpoint is written every this many steps.
"save_every": 400,
"timeout_sec": 5.0, # max seconds per code execution
}
# ─── Code Extraction ─────────────────────────────────────────────────────────
def _unwrap_fence(text: str) -> str | None:
    """Return the body of the first Python-looking Markdown fence in *text*, or None."""
    # Fences explicitly tagged as Python win. `python3?` keeps the "3" of a
    # ```python3 tag out of the captured code.
    tagged = re.search(
        r"```(?:python3?|py\b)\s*([\s\S]*?)```", text, re.IGNORECASE
    )
    if tagged:
        return tagged.group(1).strip()
    # Otherwise pair fences up in order and take the first one without a
    # language tag; fences tagged with another language (json, bash, ...) are
    # skipped rather than mistaken for code.
    for fence in re.finditer(r"```([^\n`]*)\n([\s\S]*?)```", text):
        if not fence.group(1).strip():
            return fence.group(2).strip()
    return None


def extract_code(response: str) -> str:
    """Extract the Python source from a model response.

    Sources are tried in this order:

    1. The body of the first ``<answer>...</answer>`` pair; a Markdown fence
       inside it is unwrapped.
    2. The first Markdown fence anywhere in the response.
    3. Everything from the first line that starts a ``def`` through the end of
       the response.

    Fences tagged ``python``, ``python3`` or ``py`` are preferred; an untagged
    fence is used when no tagged one exists. Fences tagged with any other
    language are ignored.

    Args:
        response: Raw text produced by the model.

    Returns:
        The extracted code, or ``""`` when nothing code-like is found. Results
        of steps 1 and 2 are stripped; step 3 returns the lines unmodified.
    """
    # 1. <answer>...</answer>
    answer = re.search(r"<answer>([\s\S]*?)</answer>", response, re.IGNORECASE)
    if answer:
        block = answer.group(1).strip()
        fenced = _unwrap_fence(block)
        return block if fenced is None else fenced

    # 2. Markdown fence anywhere
    fenced = _unwrap_fence(response)
    if fenced is not None:
        return fenced

    # 3. First `def` line and everything after it
    lines = response.split("\n")
    for index, line in enumerate(lines):
        if re.match(r"^\s*def \w+", line):
            return "\n".join(lines[index:])
    return ""
# ─── Execution Sandbox ────────────────────────────────────────────────────────
@dataclass
class ExecResult:
passed: int
total: int
error: str
def run_code_tests(code: str, test_cases: str, timeout: float = 5.0) -> ExecResult:
"""Execute code + test cases in a sandboxed namespace with timeout."""
if not code.strip():
return ExecResult(0, 1, "no_code")
# Parse check
try:
ast.parse(code)
except SyntaxError as e:
return ExecResult(0, 1, f"SyntaxError: {e}")
# Count assertions
total = max(1, test_cases.count("assert"))
namespace: dict = {"__builtins__": __builtins__}
stdout_buf = io.StringIO()
stderr_buf = io.StringIO()
try:
with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
exec(compile(code, "<solution>", "exec"), namespace)
except Exception as e:
return ExecResult(0, total, f"exec_error: {e}")
# Run each assertion separately for partial credit
passed = 0
assertion_lines = [
line.strip() for line in test_cases.split("\n")
if line.strip().startswith("assert")
]
if not assertion_lines:
assertion_lines = [test_cases.strip()]
for assertion in assertion_lines:
try:
with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
exec(compile(assertion, "<test>", "exec"), namespace)
passed += 1
except AssertionError:
pass
except Exception:
pass
return ExecResult(passed, max(len(assertion_lines), 1), "")
# ─── Code Reward Function ─────────────────────────────────────────────────────
def code_reward(
    generated: str,
    test_cases: str,
    timeout: float = 5.0,
) -> float:
    """Score one completion: execution correctness plus a small format bonus.

    Format bonus (summed, at most 0.2):
        +0.10 a non-empty ``<think>...</think>`` section
        +0.05 a non-empty ``<answer>...</answer>`` section
        +0.05 the extracted code contains a ``def``

    Base score, to which the format bonus is added:
        -0.5        no code could be extracted
        -0.3        the code has a syntax error
        -0.1        the code raised while being loaded
        1.0         every assertion passed
        0.5 * ratio at least half of the assertions passed
        0.1 * ratio fewer than half passed

    Args:
        generated: Text produced by the model.
        test_cases: Newline-separated assertions for the problem.
        timeout: Time limit in seconds forwarded to ``run_code_tests``.

    Returns:
        The reward as a float.
    """
    code = extract_code(generated)
    flags = re.IGNORECASE

    # The training prompt already ends with an opening <think> tag and only the
    # newly generated tokens are scored, so a completion normally contains just
    # the closing tag. Restore the opening tag before checking, otherwise the
    # reasoning bonus could never be earned.
    think_probe = generated
    if not re.search(r"<think>", generated, flags):
        think_probe = "<think>" + generated
    has_think = bool(re.search(r"<think>[\s\S]+</think>", think_probe, flags))
    has_answer = bool(re.search(r"<answer>[\s\S]+</answer>", generated, flags))
    has_def = bool(re.search(r"def \w+", code))
    format_score = (
        (0.1 if has_think else 0)
        + (0.05 if has_answer else 0)
        + (0.05 if has_def else 0)
    )

    if not code:
        return -0.5 + format_score

    result = run_code_tests(code, test_cases, timeout)
    if "SyntaxError" in result.error:
        return -0.3 + format_score
    if "exec_error" in result.error:
        return -0.1 + format_score

    ratio = result.passed / max(result.total, 1)
    if ratio == 1.0:
        correctness = 1.0
    elif ratio >= 0.5:
        correctness = 0.5 * ratio
    else:
        correctness = 0.1 * ratio
    return correctness + format_score
# ─── Code Dataset ─────────────────────────────────────────────────────────────
class CodeDataset(Dataset):
    """Coding problems loaded from a JSONL file, one JSON object per line.

    Blank lines are skipped, and so is any record lacking a truthy
    ``question`` or ``test_cases`` entry.
    """

    def __init__(self, path: str, tokenizer: Tokenizer, max_prompt_len: int = 512):
        """Read every usable record from *path* into memory.

        Args:
            path: UTF-8 encoded JSONL file.
            tokenizer: Tokenizer used to encode prompts in ``__getitem__``.
            max_prompt_len: Maximum number of prompt token ids kept per item.
        """
        self.tokenizer = tokenizer
        self.max_prompt_len = max_prompt_len
        with open(path, encoding="utf-8") as handle:
            parsed = (json.loads(raw) for raw in map(str.strip, handle) if raw)
            self.records: list[dict] = [
                item
                for item in parsed
                if item.get("question") and item.get("test_cases")
            ]

    def __len__(self) -> int:
        """Return the number of usable problems."""
        return len(self.records)

    def __getitem__(self, idx: int) -> dict:
        """Build the chat-style prompt for problem *idx* and tokenize it.

        The prompt ends with an opening ``<think>`` tag. Token ids beyond
        ``max_prompt_len`` are dropped from the end.
        """
        item = self.records[idx]
        prompt_text = "".join(
            (
                f"<bos><system>{CODE_SYSTEM}</system>\n",
                f"<user>{item['question']}</user>\n",
                f"<assistant><think>",
            )
        )
        token_ids = self.tokenizer.encode(prompt_text).ids[: self.max_prompt_len]
        return {
            "question": item["question"],
            "test_cases": item["test_cases"],
            "prompt_ids": token_ids,
            "level": item.get("level", 1),
            "category": item.get("category", "unknown"),
        }
def code_grpo_collate(batch: list[dict], pad_id: int = 0) -> dict:
    """Merge dataset items into one batch, right-padding the prompt ids.

    Args:
        batch: Items as returned by ``CodeDataset.__getitem__``.
        pad_id: Token id appended to prompts shorter than the longest one.

    Returns:
        A dict with the questions, test cases and levels as lists, the padded
        prompt ids as a ``torch.long`` tensor, and the unpadded prompt lengths
        as a ``torch.long`` tensor.
    """
    width = max(len(sample["prompt_ids"]) for sample in batch)
    padded_rows = []
    for sample in batch:
        row = sample["prompt_ids"]
        padded_rows.append(row + [pad_id] * (width - len(row)))
    row_lens = [len(sample["prompt_ids"]) for sample in batch]

    collated = {}
    collated["questions"] = [sample["question"] for sample in batch]
    collated["test_cases"] = [sample["test_cases"] for sample in batch]
    collated["prompt_ids"] = torch.tensor(padded_rows, dtype=torch.long)
    collated["prompt_lens"] = torch.tensor(row_lens, dtype=torch.long)
    collated["levels"] = [sample["level"] for sample in batch]
    return collated
# ─── Code GRPO Trainer ────────────────────────────────────────────────────────
class CodeGRPOTrainer:
    """GRPO trainer whose reward comes from executing generated code.

    For every problem the trainer samples ``n_samples`` completions, scores
    each one with ``code_reward``, turns the scores into group advantages and
    takes a single optimizer step on the accumulated GRPO policy loss.
    """

    def __init__(self, cfg: dict | None = None, model_cfg: OmegaConfig | None = None):
        """Store the configuration and reset counters; no files are read here.

        Args:
            cfg: Trainer configuration. ``None`` selects ``CODE_GRPO_CFG``.
            model_cfg: Configuration for a freshly initialised model;
                ``small_config()`` when omitted. ``setup`` replaces it with the
                checkpoint's configuration if a checkpoint is loaded.
        """
        # None (rather than the dict itself) is the default so the shared
        # module-level dict is not baked into the signature.
        self.cfg = CODE_GRPO_CFG if cfg is None else cfg
        self.model_cfg = model_cfg or small_config()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = (
            torch.bfloat16 if self.cfg.get("dtype") == "bfloat16" else torch.float16
        )
        self.step = 0
        self.stats = {"total": 0, "pass_all": 0, "partial": 0, "fail": 0}

    def setup(self):
        """Load the tokenizer, dataset and model, and create the optimizer.

        Raises:
            FileNotFoundError: If the tokenizer file does not exist.
            ValueError: If the dataset contains no usable problem.
        """
        print(f"Code GRPO | device={self.device} | G={self.cfg['n_samples']}")
        tok_path = self.cfg["tokenizer_path"]
        if not Path(tok_path).exists():
            raise FileNotFoundError(f"Tokenizer not found: {tok_path}")
        self.tokenizer = Tokenizer.from_file(tok_path)

        ds = CodeDataset(
            self.cfg["data_path"], self.tokenizer, self.cfg["max_prompt_len"]
        )
        if len(ds) == 0:
            # Fail with a clear message instead of an obscure error from the
            # sampler or the training loop.
            raise ValueError(f"No usable problems in {self.cfg['data_path']}")
        self.loader = DataLoader(
            ds, batch_size=1, shuffle=True,
            collate_fn=lambda b: code_grpo_collate(b, pad_id=self.model_cfg.pad_token_id),
            num_workers=0,
        )
        print(f"Code dataset: {len(ds):,} problems")

        ref_path = self.cfg.get("ref_checkpoint")
        if ref_path and Path(ref_path).exists():
            # SECURITY: weights_only=False unpickles arbitrary objects. The
            # checkpoint stores the model config object under "model_cfg", so
            # only load checkpoints from trusted sources.
            ckpt = torch.load(ref_path, map_location=self.device, weights_only=False)
            saved_cfg: OmegaConfig = ckpt["model_cfg"]
            self.model = OmegaModel(saved_cfg).to(self.device)
            self.model.load_state_dict(ckpt["model_state"])
            self.model_cfg = saved_cfg
        else:
            self.model = OmegaModel(self.model_cfg).to(self.device)

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=float(self.cfg.get("lr", 1e-6)), betas=(0.9, 0.95)
        )

    @torch.no_grad()
    def _sample(self, prompt_ids: torch.Tensor) -> list[str]:
        """Generate ``n_samples`` completions for one prompt and decode them.

        Only the tokens after the prompt are decoded. The model is left in
        eval mode.
        """
        self.model.eval()
        completions: list[str] = []
        for _ in range(self.cfg["n_samples"]):
            generated = self.model.generate(
                prompt_ids.to(self.device),
                max_new_tokens=self.cfg["max_new_tokens"],
                temperature=self.cfg["temperature"],
                top_p=0.95,
            )
            new_tokens = generated[0, prompt_ids.shape[1]:].tolist()
            completions.append(self.tokenizer.decode(new_tokens))
        return completions

    def train_step(self, batch: dict) -> float:
        """Sample, score and take one optimizer step for a single problem.

        Args:
            batch: Output of ``code_grpo_collate``; only the first item is used.

        Returns:
            The sum of the per-completion losses. ``0.0`` when every completion
            received the same reward (no learning signal) or when no completion
            contained any token.
        """
        test_cases = batch["test_cases"][0]
        prompt_ids = batch["prompt_ids"][:1]

        completions = self._sample(prompt_ids)
        rewards = [
            code_reward(c, test_cases, timeout=self.cfg["timeout_sec"])
            for c in completions
        ]

        # Stats
        for r in rewards:
            self.stats["total"] += 1
            if r >= 0.9:
                self.stats["pass_all"] += 1
            elif r >= 0.4:
                self.stats["partial"] += 1
            else:
                self.stats["fail"] += 1

        # Identical rewards give zero advantage for every sample.
        if all(r == rewards[0] for r in rewards):
            return 0.0

        advantages = compute_group_advantages(rewards)
        self.model.train()
        prompt_dev = prompt_ids.to(self.device)
        p_len = prompt_ids.shape[1]
        total_loss = 0.0
        accumulated = False

        for comp_text, adv in zip(completions, advantages):
            # NOTE(review): the completion is re-tokenized from its decoded
            # text, which is not guaranteed to reproduce the sampled token
            # ids. Consider keeping the ids from _sample; confirm against
            # model.generate / tokenizer.decode behaviour first.
            enc = self.tokenizer.encode(comp_text)
            token_ids = enc.ids[:self.cfg["max_new_tokens"]]
            if not token_ids:
                # An empty completion makes the mean log-prob below NaN, which
                # would poison the gradients of the whole step.
                continue
            comp_ids = torch.tensor([token_ids], dtype=torch.long)
            comp_dev = comp_ids.to(self.device)

            with torch.no_grad():
                full_ids = torch.cat([prompt_dev, comp_dev], dim=1)
                out_ref = self.model(full_ids)
                lp_ref = F.log_softmax(out_ref["logits"][:, :-1, :], dim=-1)
                target = full_ids[:, 1:]
                old_lp = lp_ref.gather(2, target.unsqueeze(-1)).squeeze(-1)
                # Position p_len - 1 predicts the first completion token.
                old_seq_lp = old_lp[:, p_len - 1:p_len - 1 + comp_ids.shape[1]].mean()

            adv_t = torch.tensor([adv], device=self.device, dtype=self.dtype)
            with torch.amp.autocast(
                device_type=self.device.type, dtype=self.dtype,
                enabled=self.device.type == "cuda"
            ):
                # Scaled so the accumulated gradient averages over the group.
                loss = grpo_policy_loss(
                    self.model,
                    prompt_dev,
                    comp_dev,
                    adv_t,
                    old_seq_lp,
                    clip_eps=float(self.cfg.get("clip_eps", 0.2)),
                ) / self.cfg["n_samples"]
            loss.backward()
            accumulated = True
            total_loss += loss.item() * self.cfg["n_samples"]

        if not accumulated:
            return 0.0

        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), float(self.cfg.get("grad_clip", 1.0))
        )
        self.optimizer.step()
        self.optimizer.zero_grad()
        return total_loss

    def save(self, tag: str = "code_grpo_latest"):
        """Write the step, model weights and model config to ``out_dir``.

        The output directory is created if it does not exist yet.
        """
        path = Path(self.cfg["out_dir"]) / f"omega_{tag}.pt"
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save({
            "step": self.step,
            "model_state": self.model.state_dict(),
            "model_cfg": self.model_cfg,
        }, path)
        print(f" Saved → {path}")

    def _pass_rate(self) -> str:
        """Format the running pass/partial/fail percentages for logging."""
        t = max(self.stats["total"], 1)  # avoid dividing by zero before the first step
        return (f"pass={self.stats['pass_all']/t*100:.1f}% "
                f"partial={self.stats['partial']/t*100:.1f}% "
                f"fail={self.stats['fail']/t*100:.1f}%")

    def train(self):
        """Run ``setup`` and train for ``max_steps`` steps.

        Progress is printed every 10 steps, a checkpoint is written every
        ``save_every`` steps and once more at the end.
        """
        self.setup()
        data_iter = iter(self.loader)
        t0 = time.time()
        running_loss = 0.0
        print(f"Code GRPO for {self.cfg['max_steps']:,} steps\n")

        while self.step < self.cfg["max_steps"]:
            try:
                batch = next(data_iter)
            except StopIteration:
                # Dataset exhausted: start another pass.
                data_iter = iter(self.loader)
                batch = next(data_iter)

            running_loss += self.train_step(batch)
            self.step += 1

            if self.step % 10 == 0:
                dt = time.time() - t0
                print(f"step {self.step:5d} | loss {running_loss/10:.4f} | "
                      f"{self._pass_rate()} | {dt:.1f}s")
                running_loss = 0.0
                t0 = time.time()

            if self.step % self.cfg["save_every"] == 0:
                self.save(f"code_grpo_step{self.step}")

        self.save("code_grpo_final")
        print(f"\nCode GRPO done! Final: {self._pass_rate()}")
if __name__ == "__main__":
    # Script entry point: train with the default configuration.
    trainer = CodeGRPOTrainer()
    trainer.train()

# Scraped page footer (not part of the program) — Xet Storage Details
# Size: 14.5 kB · Xet hash: e470a6a594dc6d3fbe2ba1d27cdaf38f09df42aa94fb3b1c02504953aa9edae9
# Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.