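"""GRPO fine-tuning of Llama-3.2-3B-Instruct with Unsloth 4-bit LoRA against a
live CommitGuard HTTP environment.

The env server is expected to expose POST /reset (returning an observation)
and POST /step (returning a reward); see get_reward_from_env below.

Example invocation (the filename train_grpo.py is an assumption, not fixed by
this script; the env server must already be listening on ENV_URL):

    ENV_URL=http://localhost:8000 python train_grpo.py --samples 200 --max-steps 300
"""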
from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path

import requests
import torch
from datasets import Dataset

# Unsloth must patch TRL's GRPO classes before trl is imported.
from unsloth import FastLanguageModel, PatchFastRL

PatchFastRL("GRPO", FastLanguageModel)

from trl import GRPOConfig, GRPOTrainer  # noqa: E402  (imported after patching)

sys.path.insert(0, str(Path(__file__).resolve().parent))
from agent_prompt import SYSTEM_PROMPT, get_agent_prompt  # noqa: E402

# --- Configuration ---
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct")
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs/commitguard-llama-3b")
WANDB_PROJECT = os.getenv("WANDB_PROJECT", "commitguard")
os.environ["WANDB_PROJECT"] = WANDB_PROJECT  # read by wandb when report_to="wandb"


# --- Reward: one reset + verdict per completion ---
def get_reward_from_env(prompts, completions, **kwargs) -> list[float]:
    """Score each completion by submitting it as an action to the live env."""
    rewards = []
    for completion in completions:  # prompts unused: each completion replays a fresh episode
        try:
            # Reset to get a fresh episode
            r = requests.post(f"{ENV_URL}/reset", json={}, timeout=10)
            if r.status_code != 200:
                rewards.append(-0.5)  # env refused the reset: mild penalty
                continue
            # Send the model's completion as the action
            text = completion[-1]["content"] if isinstance(completion, list) else str(completion)
            r = requests.post(f"{ENV_URL}/step", json={"action": text}, timeout=10)
            if r.status_code == 200:
                rewards.append(float(r.json().get("reward", 0.0)))
            else:
                rewards.append(-0.5)
        except Exception:
            rewards.append(-1.0)  # network/transport failure: strongest penalty
    return rewards
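

# Optional sanity check of the env contract assumed above: POST /reset ->
# {"observation": ...} and POST /step -> {"reward": ...}. The "action" string
# below is a placeholder; the real verdict format is whatever the env expects.
# This helper is illustrative only and is never called during training.
def _smoke_test_env() -> None:
    obs = requests.post(f"{ENV_URL}/reset", json={}, timeout=10).json()["observation"]
    print("reset ok; diff bytes:", len(obs["diff"]))
    step = requests.post(f"{ENV_URL}/step", json={"action": "placeholder verdict"}, timeout=10).json()
    print("step ok; reward:", step.get("reward"))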


def build_dataset(n_samples: int) -> Dataset:
    """Collect chat-format prompts by resetting the env n_samples times."""
    print(f"Fetching {n_samples} training prompts from {ENV_URL}...")
    samples = []
    for i in range(n_samples):
        try:
            r = requests.post(f"{ENV_URL}/reset", json={}, timeout=10)
            if r.status_code != 200:
                continue
            obs = r.json()["observation"]
            user_msg = get_agent_prompt(
                obs["diff"], obs["available_files"], obs.get("step_idx", 0)
            )
            samples.append({
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_msg},
                ],
            })
        except Exception:
            continue
        if (i + 1) % 50 == 0:
            print(f"  fetched {i + 1}/{n_samples}")
    print(f"Built dataset with {len(samples)} samples.")
    return Dataset.from_list(samples)


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--samples", type=int, default=200)
    ap.add_argument("--max-steps", type=int, default=300)
    ap.add_argument("--save-steps", type=int, default=50)
    ap.add_argument("--num-generations", type=int, default=4)
    ap.add_argument("--batch-size", type=int, default=1)
    ap.add_argument("--grad-accum", type=int, default=4)
    ap.add_argument("--lr", type=float, default=5e-6)
    ap.add_argument("--no-wandb", action="store_true")
    args = ap.parse_args()

    # 1. Load model
    print(f"Loading {MODEL_NAME} with Unsloth 4-bit...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=2048,
        load_in_4bit=True,
        fast_inference=True,  # enable Unsloth's vLLM backend for generation
        max_lora_rank=16,     # must be >= the LoRA r used below
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=8,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    # 2. Build dataset from live env
    dataset = build_dataset(args.samples)

    # 3. GRPO config
    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_generations=args.num_generations,
        max_completion_length=512,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=1,
        save_steps=args.save_steps,
        max_steps=args.max_steps,
        report_to="none" if args.no_wandb else "wandb",
        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported(),
    )

    # 4. Train
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[get_reward_from_env],
        args=training_args,
        train_dataset=dataset,
    )
    print("Starting GRPO training...")
    trainer.train()

    # 5. Save
    final_dir = f"{OUTPUT_DIR}/final"
    model.save_pretrained_merged(final_dir, tokenizer, save_method="lora")
    print(f"Training complete. LoRA adapter saved to {final_dir}")


if __name__ == "__main__":
    main()