Aswini-Kumar commited on
Commit
5d27dfe
·
verified ·
1 Parent(s): def99af

Fix num_generations=2

Browse files
Files changed (1) hide show
  1. train_data_centric.py +4 -4
train_data_centric.py CHANGED
@@ -480,13 +480,13 @@ def run_grpo_training(model, tokenizer, resume_from_checkpoint=None, max_steps:
480
 
481
  grpo_config = GRPOConfig(
482
  output_dir="./data-centric-checkpoints",
483
- # WHY batch_size=1, num_generations=1:
 
 
484
  # Each generation = 1 full live episode (~100s on T4 with env).
485
- # 2 generations x 200 steps = ~11 hrs. 1 generation x 50 steps = ~1.5 hrs.
486
- # GRPO minimum requires num_generations >= 1.
487
  per_device_train_batch_size=1,
488
  gradient_accumulation_steps=2, # effective batch = 2
489
- num_generations=1,
490
  max_completion_length=30, # longest command is ~15 chars
491
  max_prompt_length=400,
492
  # WHY max_steps=50:
 
480
 
481
  grpo_config = GRPOConfig(
482
  output_dir="./data-centric-checkpoints",
483
+ # WHY batch_size=1, num_generations=2:
484
+ # GRPO REQUIRES num_generations >= 2 to compute advantages (compares outputs).
485
+ # Setting 1 raises ValueError. 2 is the minimum.
486
  # Each generation = 1 full live episode (~100s on T4 with env).
 
 
487
  per_device_train_batch_size=1,
488
  gradient_accumulation_steps=2, # effective batch = 2
489
+ num_generations=2, # minimum required by GRPO
490
  max_completion_length=30, # longest command is ~15 chars
491
  max_prompt_length=400,
492
  # WHY max_steps=50: