from __future__ import annotations import os import sys import json import argparse from pathlib import Path import requests import torch from datasets import Dataset from trl import GRPOConfig, GRPOTrainer from unsloth import FastLanguageModel, PatchFastRL sys.path.insert(0, str(Path(__file__).resolve().parent)) from agent_prompt import SYSTEM_PROMPT, get_agent_prompt PatchFastRL("GRPO", FastLanguageModel) # --- Configuration --- MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct") ENV_URL = os.getenv("ENV_URL", "http://localhost:8000") OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs/commitguard-llama-3b") WANDB_PROJECT = os.getenv("WANDB_PROJECT", "commitguard") # --- Reward: one reset + verdict per completion --- def get_reward_from_env(prompts, completions, **kwargs) -> list[float]: rewards = [] for prompt, completion in zip(prompts, completions): try: # Reset to get a fresh episode r = requests.post(f"{ENV_URL}/reset", json={}, timeout=10) if r.status_code != 200: rewards.append(-0.5) continue # Send the model's completion as the action text = completion[-1]["content"] if isinstance(completion, list) else str(completion) r = requests.post(f"{ENV_URL}/step", json={"action": text}, timeout=10) if r.status_code == 200: rewards.append(float(r.json().get("reward", 0.0))) else: rewards.append(-0.5) except Exception: rewards.append(-1.0) return rewards def build_dataset(n_samples: int) -> Dataset: print(f"Fetching {n_samples} training prompts from {ENV_URL}...") samples = [] for i in range(n_samples): try: r = requests.post(f"{ENV_URL}/reset", json={}, timeout=10) if r.status_code != 200: continue obs = r.json()["observation"] user_msg = get_agent_prompt( obs["diff"], obs["available_files"], obs.get("step_idx", 0) ) samples.append({ "prompt": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_msg}, ], }) except Exception: continue if (i + 1) % 50 == 0: print(f" fetched {i + 1}/{n_samples}") print(f"Built dataset with {len(samples)} samples.") return Dataset.from_list(samples) def main(): ap = argparse.ArgumentParser() ap.add_argument("--samples", type=int, default=200) ap.add_argument("--max-steps", type=int, default=300) ap.add_argument("--save-steps", type=int, default=50) ap.add_argument("--num-generations", type=int, default=4) ap.add_argument("--batch-size", type=int, default=1) ap.add_argument("--grad-accum", type=int, default=4) ap.add_argument("--lr", type=float, default=5e-6) ap.add_argument("--no-wandb", action="store_true") args = ap.parse_args() # 1. Load Model print(f"Loading {MODEL_NAME} with Unsloth 4-bit...") model, tokenizer = FastLanguageModel.from_pretrained( model_name=MODEL_NAME, max_seq_length=2048, load_in_4bit=True, fast_inference=True, max_lora_rank=16, ) model = FastLanguageModel.get_peft_model( model, r=8, target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], lora_alpha=16, lora_dropout=0, bias="none", use_gradient_checkpointing="unsloth", random_state=3407, ) # 2. Build dataset from live env dataset = build_dataset(args.samples) # 3. GRPO config training_args = GRPOConfig( output_dir=OUTPUT_DIR, num_generations=args.num_generations, max_completion_length=512, per_device_train_batch_size=args.batch_size, gradient_accumulation_steps=args.grad_accum, learning_rate=args.lr, logging_steps=1, save_steps=args.save_steps, max_steps=args.max_steps, report_to="none" if args.no_wandb else "wandb", bf16=torch.cuda.is_bf16_supported(), fp16=not torch.cuda.is_bf16_supported(), ) # 4. Train trainer = GRPOTrainer( model=model, processing_class=tokenizer, reward_funcs=[get_reward_from_env], args=training_args, train_dataset=dataset, ) print("Starting GRPO training...") trainer.train() # 5. Save final_dir = f"{OUTPUT_DIR}/final" model.save_pretrained_merged(final_dir, tokenizer, save_method="lora") print(f"Training complete. LoRA adapter saved to {final_dir}") if __name__ == "__main__": main()