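"""GRPO training for CommitGuard.

Fine-tunes a 4-bit Llama model with Unsloth + TRL's GRPOTrainer, using an
HTTP environment server (ENV_URL) both to build the prompt dataset and to
score completions as rewards.
"""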
from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path

import requests
import torch
from datasets import Dataset

from unsloth import FastLanguageModel, PatchFastRL

# Patch TRL's GRPO classes before importing them from trl, per Unsloth's GRPO usage.
PatchFastRL("GRPO", FastLanguageModel)

from trl import GRPOConfig, GRPOTrainer

sys.path.insert(0, str(Path(__file__).resolve().parent))
from agent_prompt import SYSTEM_PROMPT, get_agent_prompt
# --- Configuration ---
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct")
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs/commitguard-llama-3b")
WANDB_PROJECT = os.getenv("WANDB_PROJECT", "commitguard")
os.environ.setdefault("WANDB_PROJECT", WANDB_PROJECT)  # wandb reads the project name from this env var
# --- Reward: one reset + verdict per completion ---
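# TRL calls each reward function with the batch's prompts and completions
# (plus any extra dataset columns via **kwargs) and expects one float per
# completion.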
def get_reward_from_env(prompts, completions, **kwargs) -> list[float]:
    rewards = []
    for completion in completions:
        try:
            # Reset to get a fresh episode.
            r = requests.post(f"{ENV_URL}/reset", json={}, timeout=10)
            if r.status_code != 200:
                rewards.append(-0.5)
                continue
            # Send the model's completion as the action.
            text = completion[-1]["content"] if isinstance(completion, list) else str(completion)
            r = requests.post(f"{ENV_URL}/step", json={"action": text}, timeout=10)
            if r.status_code == 200:
                rewards.append(float(r.json().get("reward", 0.0)))
            else:
                rewards.append(-0.5)
        except Exception:
            rewards.append(-1.0)  # env unreachable: penalize but keep training
    return rewards
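
# Assumed env HTTP contract (inferred from the calls above; not verified here):
#   POST /reset             -> 200 {"observation": {"diff": ..., "available_files": [...], "step_idx": ...}}
#   POST /step {"action"}   -> 200 {"reward": <float>, ...}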

def build_dataset(n_samples: int) -> Dataset:
    print(f"Fetching {n_samples} training prompts from {ENV_URL}...")
    samples = []
    for i in range(n_samples):
        try:
            r = requests.post(f"{ENV_URL}/reset", json={}, timeout=10)
            if r.status_code != 200:
                continue
            obs = r.json()["observation"]
            user_msg = get_agent_prompt(
                obs["diff"], obs["available_files"], obs.get("step_idx", 0)
            )
            # Conversational format: GRPOTrainer applies the tokenizer's chat template.
            samples.append({
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_msg},
                ],
            })
        except Exception:
            continue
        if (i + 1) % 50 == 0:
            print(f"  fetched {i + 1}/{n_samples}")
    print(f"Built dataset with {len(samples)} samples.")
    return Dataset.from_list(samples)

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--samples", type=int, default=200)
    ap.add_argument("--max-steps", type=int, default=300)
    ap.add_argument("--save-steps", type=int, default=50)
    ap.add_argument("--num-generations", type=int, default=4)
    ap.add_argument("--batch-size", type=int, default=1)
    ap.add_argument("--grad-accum", type=int, default=4)
    ap.add_argument("--lr", type=float, default=5e-6)
    ap.add_argument("--no-wandb", action="store_true")
    args = ap.parse_args()
    # 1. Load Model
    print(f"Loading {MODEL_NAME} with Unsloth 4-bit...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=2048,
        load_in_4bit=True,
        fast_inference=True,
        max_lora_rank=16,
    )
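    # fast_inference=True enables Unsloth's vLLM-backed generation, which speeds
    # up GRPO's repeated sampling; max_lora_rank caps the adapter rank used below.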
    model = FastLanguageModel.get_peft_model(
        model,
        r=8,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
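    # Only the LoRA adapter weights train; the 4-bit base model stays frozen,
    # keeping GRPO's memory footprint small.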
    # 2. Build dataset from live env
    dataset = build_dataset(args.samples)

    # 3. GRPO config
    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_generations=args.num_generations,
        max_completion_length=512,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=1,
        save_steps=args.save_steps,
        max_steps=args.max_steps,
        report_to="none" if args.no_wandb else "wandb",
        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported(),
    )
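    # Note: GRPOTrainer requires the effective batch to be divisible by
    # num_generations, since each prompt's group of completions is scored
    # together; the defaults (batch 1 x grad-accum 4, 4 generations) satisfy this.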
    # 4. Train
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[get_reward_from_env],
        args=training_args,
        train_dataset=dataset,
    )
    print("Starting GRPO training...")
    trainer.train()

    # 5. Save
    final_dir = f"{OUTPUT_DIR}/final"
    model.save_pretrained_merged(final_dir, tokenizer, save_method="lora")
    print(f"Training complete. LoRA adapter saved to {final_dir}")

if __name__ == "__main__":
    main()
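
# Example usage (assumes a CommitGuard env server is reachable at ENV_URL):
#   ENV_URL=http://localhost:8000 python scripts/train_grpo.py \
#       --samples 200 --max-steps 300 --num-generations 4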