diff --git a/ICL/RL/CHECKLIST.md b/ICL/RL/CHECKLIST.md new file mode 100644 index 0000000000000000000000000000000000000000..459e17a7f962d448fdae0122b8c3880f01875e57 --- /dev/null +++ b/ICL/RL/CHECKLIST.md @@ -0,0 +1,198 @@ +# Pre-Training Checklist + +## ✅ Implementation Complete + +### Core Components +- [x] Reward functions (R_outcome + R_rel + R_penalty) +- [x] Retrieval environment with SigLIP +- [x] Data loading (M3IT 50:50 split) +- [x] GRPO trainer integration +- [x] Self-exclusion mechanism +- [x] Sub-dataset candidate pools +- [x] Timeout handling + +### Testing & Validation +- [x] Reward function unit tests (7/7 passed) +- [x] SigLIP score scaling +- [x] F1 score computation +- [x] Answer extraction +- [x] All test cases validated + +### Documentation +- [x] README.md with usage guide +- [x] IMPLEMENTATION_SUMMARY.md +- [x] Code comments and docstrings +- [x] Configuration file (config.yaml) +- [x] Quick start script + +### Utilities +- [x] Test script (test_rewards.py) +- [x] Visualization script (visualize_rewards.py) +- [x] Inference example (inference_example.py) +- [x] Requirements.txt + +## 🚦 Before Training + +### 1. Environment Setup +```bash +cd /workspace/xiaobin/RL +bash quickstart.sh +``` + +### 2. Verify Model Path +Check that SFT checkpoint exists: +```bash +ls -lh /workspace/xiaobin/SFT_model/hf_qwen3vl_siglip_vqa_iter_0000881 +``` + +### 3. Test Reward Functions +```bash +python test_rewards.py +``` +Expected: All tests pass ✓ + +### 4. Configure Training +Edit `train_grpo.py`: +- [ ] MODEL_PATH: Verify SFT checkpoint path +- [ ] OUTPUT_DIR: Set output directory +- [ ] LEARNING_RATE: Default 1e-5 +- [ ] NUM_GENERATIONS: Default 8 (group size) +- [ ] MAX_TURNS: Default 3 +- [ ] MAX_SAMPLES_PER_DATASET: Set to small number for testing, None for full + +### 5. Hardware Check +- [ ] GPU available: `nvidia-smi` +- [ ] VRAM: Minimum 40GB (A100), recommended 80GB (H100) +- [ ] Disk space: Check output directory has space + +### 6. Data Access +- [ ] M3IT dataset accessible via HuggingFace +- [ ] Internet connection for dataset download +- [ ] Sufficient disk space for dataset cache + +### 7. 
Wandb Setup (Optional)
+```bash
+wandb login
+```
+Or set `USE_WANDB = False` in train_grpo.py
+
+## 🎯 Training Stages
+
+### Stage 1: Smoke Test (Recommended)
+```python
+# In train_grpo.py, set:
+MAX_SAMPLES_PER_DATASET = 100  # Small dataset
+NUM_EPOCHS = 1
+```
+Run: `python train_grpo.py`
+
+Expected: Training completes without errors
+
+### Stage 2: Full Training
+```python
+# In train_grpo.py, set:
+MAX_SAMPLES_PER_DATASET = None  # Full dataset
+NUM_EPOCHS = 3
+```
+Run: `python train_grpo.py`
+
+## 📊 Monitoring
+
+### During Training
+- [ ] Check wandb dashboard for metrics
+- [ ] Monitor GPU utilization: `watch -n 1 nvidia-smi`
+- [ ] Check logs in output/ directory
+- [ ] Verify checkpoints are being saved
+
+### Key Metrics to Watch
+- **Average reward**: Should increase over time
+- **Reward distribution**: Check positive vs negative samples
+- **Retrieval rate**: % of samples using `<retrieval>`
+- **Average steps**: Should decrease (efficiency)
+- **Timeout rate**: Should decrease
+
+## 🔍 Expected Behavior
+
+### Early Training (Epoch 1)
+- Random retrieval decisions
+- High timeout rate
+- Low average reward
+- Inconsistent step counts
+
+### Mid Training (Epoch 2)
+- Learning to distinguish positive/negative
+- Decreasing timeout rate
+- Improving average reward
+- More consistent behavior
+
+### Late Training (Epoch 3)
+- Clear retrieval strategy
+- Low timeout rate
+- High average reward
+- Efficient step usage (prefer 1-shot)
+
+## 🐛 Troubleshooting
+
+### OOM (Out of Memory)
+```python
+BATCH_SIZE = 1  # Already minimal
+NUM_GENERATIONS = 4  # Reduce from 8
+gradient_accumulation_steps = 8  # Increase
+```
+
+### Slow Training
+```python
+MAX_SAMPLES_PER_DATASET = 1000  # Limit dataset size
+```
+
+### Not Learning
+- Check reward distribution in wandb
+- Verify data balance (50:50)
+- Check SigLIP embeddings are precomputed
+- Verify self-exclusion is working
+
+### Data Loading Errors
+- Check internet connection
+- Clear HuggingFace cache: `rm -rf ~/.cache/huggingface`
+- Try loading datasets individually
+
+## 📝 Post-Training
+
+### 1. Evaluate Model
+```bash
+python inference_example.py
+```
+
+### 2. Analyze Results
+- Check final reward distribution
+- Analyze retrieval patterns
+- Compare positive vs negative samples
+- Measure efficiency (avg steps)
+
+### 3. Save Best Checkpoint
+```bash
+cp -r output/checkpoint-XXXX output/best_model
+```
+
+## ✨ Success Criteria
+
+Training is successful if:
+- [ ] Average reward > 0.5
+- [ ] Timeout rate < 5%
+- [ ] Positive samples: >70% use retrieval
+- [ ] Negative samples: >70% direct answer
+- [ ] Average steps: 1.0-1.5 (efficient)
+
+## 🎓 Next Steps After Training
+
+1. Evaluate on held-out test set
+2. Analyze failure cases
+3. Fine-tune hyperparameters if needed
+4. Deploy model for inference
+5. Collect user feedback
+
+---
+
+**Status**: Ready for training ✅
+
+All components tested and validated. Proceed with smoke test first, then full training.
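+
+## 🧮 Checking the Success Criteria
+
+As a convenience for the monitoring steps above, the success criteria can also be checked mechanically once metrics are aggregated. This is a minimal sketch, not part of the training code: it assumes a plain dict of averaged metrics, where `reward` matches the key in the TRL log lines captured in `key_metrics_*.log`, and `timeout_rate` / `avg_steps` are hypothetical names for the trajectory statistics this checklist asks you to track.
+
+```python
+def meets_success_criteria(metrics: dict) -> bool:
+    """True when the 'Success Criteria' thresholds above are all met."""
+    return (
+        metrics["reward"] > 0.5                  # average reward
+        and metrics["timeout_rate"] < 0.05       # timeout rate < 5%
+        and 1.0 <= metrics["avg_steps"] <= 1.5   # efficient step usage
+    )
+```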
diff --git a/ICL/RL/FIX_SUMMARY.md b/ICL/RL/FIX_SUMMARY.md
new file mode 100644
index 0000000000000000000000000000000000000000..c1a45d961bdcccf5259ecb9277ad37f2f82b29cc
--- /dev/null
+++ b/ICL/RL/FIX_SUMMARY.md
@@ -0,0 +1,118 @@
+# Summary of GRPO Training Fixes
+
+## Problems Found
+
+### 1. Multi-GPU device mismatch error ❌
+**Problem**:
+- With `device_map="auto"`, the VLM and SigLIP models are automatically spread across multiple GPUs (cuda:0, cuda:1)
+- The Environment class hard-coded `device="cuda"` (which defaults to cuda:0)
+- As a result, inside `retrieve_top1`:
+  - `text_embeds` may live on cuda:1
+  - `image_embeds` was moved to cuda:0
+  - The matrix multiplication then fails with a device mismatch: `Expected all tensors to be on the same device, but got mat2 is on cuda:0, different from other tensors on cuda:1`
+
+**Fix**:
+- Stop using the hard-coded `self.device`
+- Look up the device of the model's parameters dynamically instead: `model_device = next(self.siglip_model.parameters()).device`
+- Move all input tensors to the model's device
+- Ensure all computation happens on the same device
+
+**Files changed**: `environment.py`
+- `_precompute_embeddings()`: use the model's device instead of self.device
+- `retrieve_top1()`: look up the model's device dynamically; keep text_embeds and image_embeds on the same device
+- `compute_image_similarity()`: use the model's device
+
+### 2. Loss near 0; the model is not learning ❌
+**Problem**:
+- From the middle of training onwards, the loss stays between -0.0013 and 0.0021, i.e. close to 0
+- `frac_reward_zero_std`: 0.875-0.975 (for 87.5%-97.5% of samples, the reward standard deviation is 0)
+- This means that for the same prompt, the 16 generated completions received nearly identical rewards
+- GRPO's advantage = reward - baseline; if all rewards are equal, the advantage is 0 and so is the loss
+
+**Root cause**:
+- `temperature=0.7` is too low, so the 16 generated completions are too similar
+- The lack of diversity drives the reward variance to 0
+
+**Fix**:
+- Raise the temperature from 0.7 to 1.1
+- More diverse generations let different completions earn different rewards
+- Only then can GRPO compute meaningful advantages and learn
+
+**Files changed**: `train_grpo.py`
+- `GRPOConfig.temperature`: 0.7 → 1.1
+
+### 3. Frequent retrieval errors ❌
+**Problem**:
+- The logs contain many "Retrieval error" messages
+- All of them are caused by the device mismatch
+
+**Fix**:
+- Fixing problem 1 (the multi-GPU device configuration) resolves this automatically
+
+## Code Changes
+
+### environment.py
+
+#### 1. `_precompute_embeddings()`
+```python
+# Before
+inputs = self.siglip_processor(...).to(self.device)
+
+# After
+model_device = next(self.siglip_model.parameters()).device
+inputs = {k: v.to(model_device) for k, v in inputs.items()}
+```
+
+#### 2. `retrieve_top1()`
+```python
+# Before
+text_inputs = self.siglip_processor(...).to(self.device)
+image_embeds = self.pool_embeddings[dataset_name].to(self.device)
+
+# After
+model_device = next(self.siglip_model.parameters()).device
+text_inputs = {k: v.to(model_device) for k, v in text_inputs.items()}
+image_embeds = self.pool_embeddings[dataset_name].to(text_embeds.device)
+```
+
+#### 3. `compute_image_similarity()`
+```python
+# Before
+inputs = self.siglip_processor(...).to(self.device)
+
+# After
+model_device = next(self.siglip_model.parameters()).device
+inputs = {k: v.to(model_device) for k, v in inputs.items()}
+```
+
+### train_grpo.py
+
+```python
+# Before
+temperature=0.7,
+
+# After
+temperature=1.1,  # Increased from 0.7 to 1.1 for more diversity
+```
+
+## Verification
+
+Run the test script to verify the device configuration:
+```bash
+cd /workspace/xiaobin/RL
+python test_device_config.py
+```
+
+## Expected Effects
+
+1. **No more device mismatch errors**: all retrieval operations run normally
+2. **Loss no longer near 0**: with more diverse generations, reward variance grows and GRPO can learn normally
+3. **More stable training**: no OOMs or crashes caused by device errors
+4. **Better model performance**: the retrieval policy can actually be learned
+
+## Further Suggestions
+
+1. **Monitor the reward distribution**: watch whether `frac_reward_zero_std` drops below 0.5
+2. **Tune the temperature**: if 1.1 is not enough, try 1.2-1.5
+3. **Check generation quality**: make sure the generated text is still sensible at the higher temperature
+4. **Save checkpoints**: save the model periodically in case training is interrupted
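+
+## Appendix: Device-Lookup Pattern
+
+For reference, the fix above boils down to one reusable pattern. This is a minimal sketch (the helper names are ours, not from `environment.py`), assuming a `torch.nn.Module` whose relevant parameters sit on one device:
+
+```python
+import torch
+
+def get_model_device(model: torch.nn.Module) -> torch.device:
+    # Device of the first parameter; a convention when the module spans GPUs
+    return next(model.parameters()).device
+
+def to_model_device(inputs: dict, model: torch.nn.Module) -> dict:
+    # Move every tensor in a processor output dict onto the model's device
+    device = get_model_device(model)
+    return {k: v.to(device) for k, v in inputs.items()}
+```
+
+With `device_map="auto"` a model's parameters can span several GPUs, so the first parameter's device is only a convention; it matches what the fixed code in `environment.py` does.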
diff --git a/ICL/RL/README.md b/ICL/RL/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c34bd9eb1986526ff39591587238505500baaab9
--- /dev/null
+++ b/ICL/RL/README.md
@@ -0,0 +1,167 @@
+# GRPO Training: Retrieval-Augmented Visual Question Answering
+
+This directory contains a GRPO (Group Relative Policy Optimization) training setup for a vision-language model that learns when to retrieve additional images and when to answer directly.
+
+## Overview
+
+The model learns to:
+- Output `<retrieval> description text` to retrieve similar images when it needs them
+- Output `<answer> answer` to answer directly when it is ready
+- Optimize retrieval efficiency (fewer steps is better)
+- Maximize answer accuracy
+
+## Files
+
+- `train_grpo.py` - Main training script
+- `reward_functions.py` - Reward computation (R_outcome + R_rel + R_penalty)
+- `environment.py` - SigLIP-based retrieval environment
+- `data_utils.py` - M3IT data loading and preprocessing
+- `requirements.txt` - Python dependencies
+- `config.yaml` - Training configuration file
+
+## Installation
+
+```bash
+pip install -r requirements.txt
+```
+
+## Data Strategy
+
+Training uses the M3IT dataset, split into two groups:
+
+**Positive samples (50%)** - retrieval needed:
+- OK-VQA / A-OKVQA (knowledge-intensive)
+- ScienceQA (science reasoning)
+
+**Negative samples (50%)** - answer directly:
+- TextVQA / OCR-VQA (reading text in images)
+- VQA-v2 (basic visual recognition)
+- CLEVR (pure logical reasoning)
+
+## Reward Function
+
+```
+R_total = R_outcome + Gate(answer correct?) × R_rel + R_penalty
+```
+
+- **R_outcome**: +1.0 for a correct answer, -1.0 for a wrong one
+- **R_penalty**: 0 / -0.1 / -0.25 / -0.5 for 0/1/2/3 steps; -2.0 on timeout
+- **R_rel**: normalized SigLIP image similarity (only applied when the answer is correct)
+
+## Usage
+
+### Basic training
+
+```bash
+python train_grpo.py
+```
+
+### Configuration
+
+Edit the configuration in `train_grpo.py`:
+
+```python
+MODEL_PATH = "/workspace/xiaobin/SFT_model/hf_qwen3vl_siglip_vqa_iter_0000881"
+OUTPUT_DIR = "/workspace/xiaobin/RL/output"
+SIGLIP_MODEL = "/workspace/siglip2-so400m-patch16-naflex"
+LEARNING_RATE = 1e-5
+BATCH_SIZE = 1
+NUM_GENERATIONS = 8  # Group size
+MAX_TURNS = 3
+```
+
+### Monitoring training
+
+Training logs are sent to Weights & Biases. Set `USE_WANDB = False` to disable.
+
+## Retrieval Environment
+
+The environment implements a while-loop interaction:
+
+1. The model generates an output
+2. If `<retrieval>` is detected:
+   - Extract the description text
+   - Use SigLIP to find the most similar image (excluding the query image)
+   - Inject the retrieved image into the conversation
+   - Continue to the next turn
+3. If `<answer>` is detected or the maximum number of turns is reached:
+   - End the trajectory
+   - Compute the reward
+
+## Key Features
+
+- **Per-dataset retrieval pools**: each dataset has its own candidate pool, making retrieval faster and more accurate
+- **Self-exclusion**: the query image is always excluded, preventing trivial solutions
+- **Non-linear step penalty**: encourages efficiency (1 step > 3 steps)
+- **Gated relevance reward**: R_rel only applies when the answer is correct
+- **Timeout handling**: infinite loops are punished heavily (-2.0)
+
+## Expected Behavior
+
+After training, the model should:
+- Retrieve for knowledge-intensive questions (OK-VQA)
+- Answer directly for visual questions (TextVQA, VQA-v2)
+- Use as few retrieval steps as possible (prefer 1 step over 3)
+- Write accurate descriptions that retrieve relevant images
+
+## Hardware Requirements
+
+- Recommended: 8×H100 (80GB) or equivalent
+- Minimum: 1×A100 (40GB) with a reduced batch size
+- SigLIP precomputation needs about 10GB of VRAM
+
+## Troubleshooting
+
+**Out of memory**: reduce `BATCH_SIZE` or `NUM_GENERATIONS`
+
+**Retrieval too slow**: reduce `MAX_SAMPLES_PER_DATASET` while debugging
+
+**Model not learning**: check the reward distribution in the wandb logs
+
+## Testing & Validation
+
+Run the test scripts to validate the implementation:
+
+```bash
+# Test the reward functions
+python test_rewards.py
+
+# Visualize the reward distribution
+python visualize_rewards.py
+```
+
+## Training Stages
+
+### Stage 1: Smoke Test (Recommended)
+```python
+MAX_SAMPLES_PER_DATASET = 100  # Small dataset
+NUM_EPOCHS = 1
+```
+
+### Stage 2: Full Training
+```python
+MAX_SAMPLES_PER_DATASET = None  # Full dataset
+NUM_EPOCHS = 3
+```
+
+## Metrics to Monitor
+
+During training, watch:
+- **Average reward**: should increase over time
+- **Reward distribution**: check the gap between positive and negative samples
+- **Retrieval rate**: percentage of samples using `<retrieval>`
+- **Average steps**: should decrease (better efficiency)
+- **Timeout rate**: should decrease
+
+## Success Criteria
+
+Training is successful when:
+- Average reward > 0.5
+- Timeout rate < 5%
+- Positive samples: >70% use retrieval
+- Negative samples: >70% answer directly
+- Average steps: 1.0-1.5 (efficient)
+
+## Reference
+
+Built on the GRPO implementation in the TRL library, following the plan in `plan.md`.
diff --git a/ICL/RL/REWARD_DESIGN.md b/ICL/RL/REWARD_DESIGN.md
new file mode 100644
index 0000000000000000000000000000000000000000..b40322b1aa00787fc28f2cf888fd9e5726c4da07
--- /dev/null
+++ b/ICL/RL/REWARD_DESIGN.md
@@ -0,0 +1,259 @@
+# Reward Function Design
+
+## Overview
+
+The reward function is designed to train a retrieval-augmented VQA model using GRPO (Group Relative Policy Optimization). The model can either answer directly or retrieve similar images to help answer questions.
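+
+Concretely, a trajectory is a plain string built from two special tags, and the reward below scores its shape. Two minimal examples (using the `<retrieval>`/`<answer>` tags from this design; they mirror Examples 1 and 2 below):
+
+```
+<answer> cat                                # direct answer, 0 retrieval steps
+<retrieval> a photo of a cat <answer> cat   # one retrieval, then the answer
+```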
+
+## Formula
+
+```
+R_total = R_outcome + R_rel + R_penalty
+```
+
+Where:
+- **R_outcome**: Answer correctness reward
+- **R_rel**: Retrieval relevance reward (gated by correctness)
+- **R_penalty**: Step penalty to discourage unnecessary retrieval
+
+---
+
+## Component Details
+
+### 1. R_outcome (Answer Correctness)
+
+Measures whether the final answer is correct using F1 score.
+
+```python
+f1 = compute_f1_score(prediction, ground_truth)
+is_correct = (f1 > 0.5)
+
+r_outcome = +1.0 if is_correct
+            -1.0 otherwise
+```
+
+**Purpose**: Primary signal - the model must answer correctly.
+
+---
+
+### 2. R_penalty (Step Penalty)
+
+Penalizes the number of retrieval steps to encourage efficiency.
+
+```python
+num_steps = count("<retrieval>", trajectory)
+
+r_penalty = {
+    0 steps:  0.0,    # Direct answer, no penalty
+    1 step:  -0.1,    # Small penalty
+    2 steps: -0.25,   # Medium penalty
+    3 steps: -0.5,    # Large penalty
+    timeout: -2.0     # Death penalty for infinite loop
+}
+```
+
+**Purpose**: Encourage the model to use retrieval only when necessary.
+
+**Timeout**: Occurs when the model exceeds `max_turns` (default: 3) without producing `<answer>`.
+
+---
+
+### 3. R_rel (Retrieval Relevance)
+
+Rewards retrieving relevant images, **only if the answer is correct** (gate mechanism).
+
+```python
+if is_correct and len(siglip_scores) > 0:
+    # Scale SigLIP scores from [0.33, 0.97] to [0, 1]
+    scaled_scores = [scale_siglip_score(s) for s in siglip_scores]
+    r_rel = mean(scaled_scores)
+else:
+    r_rel = 0.0
+```
+
+**SigLIP Score Scaling**:
+```python
+def scale_siglip_score(raw_score, min_th=0.33, max_th=0.97):
+    scaled = (raw_score - min_th) / (max_th - min_th)
+    return clip(scaled, 0.0, 1.0)
+```
+
+**Purpose**:
+- Encourage retrieving images similar to the query image
+- Only reward good retrieval if it leads to correct answers
+- Thresholds based on data analysis: P05=0.33, P95=0.97 for same-dataset pairs
+
+---
+
+## Reward Examples
+
+### Example 1: Direct Answer (Correct)
+```
+Trajectory: "<answer> cat"
+Ground Truth: "cat"
+```
+- r_outcome = +1.0 (correct)
+- r_penalty = 0.0 (no retrieval)
+- r_rel = 0.0 (no retrieval)
+- **R_total = +1.0**
+
+---
+
+### Example 2: One Retrieval (Correct, High Similarity)
+```
+Trajectory: "<retrieval> a photo of a cat <answer> cat"
+Ground Truth: "cat"
+SigLIP Score: 0.85
+```
+- r_outcome = +1.0 (correct)
+- r_penalty = -0.1 (1 retrieval)
+- r_rel = scale(0.85) = (0.85-0.33)/(0.97-0.33) ≈ 0.81
+- **R_total = +1.71**
+
+---
+
+### Example 3: Two Retrievals (Correct, Mixed Similarity)
+```
+Trajectory: "<retrieval> animal <retrieval> cat photo <answer> cat"
+Ground Truth: "cat"
+SigLIP Scores: [0.65, 0.90]
+```
+- r_outcome = +1.0 (correct)
+- r_penalty = -0.25 (2 retrievals)
+- r_rel = mean([scale(0.65), scale(0.90)]) = mean([0.50, 0.89]) ≈ 0.70
+- **R_total = +1.45**
+
+---
+
+### Example 4: One Retrieval (Incorrect)
+```
+Trajectory: "<retrieval> a photo of a dog <answer> dog"
+Ground Truth: "cat"
+SigLIP Score: 0.75
+```
+- r_outcome = -1.0 (incorrect)
+- r_penalty = -0.1 (1 retrieval)
+- r_rel = 0.0 (gated out because incorrect)
+- **R_total = -1.1**
+
+---
+
+### Example 5: Direct Answer (Incorrect)
+```
+Trajectory: "<answer> dog"
+Ground Truth: "cat"
+```
+- r_outcome = -1.0 (incorrect)
+- r_penalty = 0.0 (no retrieval)
+- r_rel = 0.0 (no retrieval)
+- **R_total = -1.0**
+
+---
+
+### Example 6: Timeout (Infinite Loop)
+```
+Trajectory: "<retrieval> query1 <retrieval> query2 <retrieval> query3"
+Ground Truth: "cat"
+Max Turns: 3
+```
+- r_outcome = -1.0 (no answer)
+- r_penalty = -2.0 (timeout)
+- r_rel = 0.0 (incorrect)
+- **R_total = -3.0**
+
+---
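+
+### Reproducing the Examples
+
+The worked examples above can be checked with a few lines of Python. This is a minimal sketch of the combination logic, not the actual `reward_functions.py` (the names here are illustrative):
+
+```python
+PENALTY_MAP = {0: 0.0, 1: -0.1, 2: -0.25, 3: -0.5}
+TIMEOUT_PENALTY = -2.0
+
+def scale_siglip_score(raw, min_th=0.33, max_th=0.97):
+    # Linear rescale from [0.33, 0.97] to [0, 1], clipped
+    return min(max((raw - min_th) / (max_th - min_th), 0.0), 1.0)
+
+def total_reward(is_correct, num_steps, siglip_scores, timeout=False):
+    r_outcome = 1.0 if is_correct else -1.0
+    r_penalty = TIMEOUT_PENALTY if timeout else PENALTY_MAP[num_steps]
+    scaled = [scale_siglip_score(s) for s in siglip_scores]
+    r_rel = sum(scaled) / len(scaled) if (is_correct and scaled) else 0.0
+    return r_outcome + r_rel + r_penalty
+
+# Example 2: correct answer after one retrieval with SigLIP score 0.85
+print(total_reward(True, 1, [0.85]))   # 1.0 - 0.1 + 0.8125 ≈ +1.71
+# Example 6: timeout after three retrievals, no answer
+print(total_reward(False, 3, [0.5, 0.5, 0.5], timeout=True))  # -3.0
+```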
+
+## Design Rationale
+
+### 1. R_outcome Dominates
+The ±1.0 range ensures correctness is the primary objective. Even with perfect retrieval (r_rel ≈ 0.8), an incorrect answer still gets negative reward.
+
+### 2. Efficiency Matters
+The non-linear penalty (-0.1, -0.25, -0.5) discourages excessive retrieval. The model learns to retrieve only when beneficial.
+
+### 3. Gate Mechanism
+R_rel is only active when the answer is correct. This prevents the model from learning to retrieve irrelevant but similar images just to get r_rel reward.
+
+### 4. Timeout Prevention
+The -2.0 death penalty strongly discourages infinite retrieval loops.
+
+---
+
+## Reward Range
+
+**Theoretical Range**: [-3.0, +1.9]
+
+- **Best case**: Correct answer with 1 highly relevant retrieval
+  - r_outcome = +1.0
+  - r_penalty = -0.1
+  - r_rel ≈ +1.0 (perfect similarity)
+  - R_total ≈ +1.9
+
+- **Worst case**: Timeout
+  - r_outcome = -1.0
+  - r_penalty = -2.0
+  - r_rel = 0.0
+  - R_total = -3.0
+
+**Typical Range**: [-1.5, +1.7]
+
+---
+
+## Implementation
+
+See `reward_functions.py`:
+- `compute_reward()`: Computes reward for a single trajectory
+- `batch_compute_rewards()`: Batch processing for GRPO
+- `scale_siglip_score()`: Scales SigLIP similarity scores
+- `compute_f1_score()`: Token-level F1 for answer correctness
+
+---
+
+## Training Dynamics
+
+### What the Model Learns
+
+1. **Answer correctly first** (r_outcome = ±1.0 is dominant)
+2. **Use retrieval strategically** (balance r_rel gain vs r_penalty cost)
+3. **Retrieve relevant images** (maximize r_rel when retrieving)
+4. **Avoid infinite loops** (timeout penalty is severe)
+
+### Expected Behavior
+
+- **Easy questions**: Direct answer (R ≈ +1.0)
+- **Hard questions needing context**: 1-2 retrievals (R ≈ +1.4 to +1.6)
+- **Ambiguous questions**: May try retrieval but learn to minimize steps
+- **Impossible questions**: Learn to answer directly rather than waste steps
+
+---
+
+## Monitoring
+
+Key metrics to track during training:
+
+1. **Mean Reward**: Should increase from negative to positive
+2. **Reward Std**: Should be > 0.5 (diversity in completions)
+3. **frac_reward_zero_std**: Should be < 0.5 (not all completions getting same reward)
+4. **Retrieval Rate**: Percentage of trajectories using `<retrieval>`
+5. **Avg Steps**: Average number of retrievals per trajectory
+6. **Timeout Rate**: Should be near 0%
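+
+A hedged sketch of how these trajectory-level statistics could be computed from a batch of rollout strings (the `<retrieval>`/`<answer>` tags are the ones used throughout this design; the helper itself is illustrative, not part of the training code):
+
+```python
+def trajectory_stats(trajectories):
+    """Aggregate retrieval/timeout statistics over a non-empty batch."""
+    n = len(trajectories)
+    return {
+        # Fraction of trajectories that retrieved at least once
+        "retrieval_rate": sum("<retrieval>" in t for t in trajectories) / n,
+        # Mean number of retrieval steps per trajectory
+        "avg_steps": sum(t.count("<retrieval>") for t in trajectories) / n,
+        # A trajectory with no <answer> tag timed out
+        "timeout_rate": sum("<answer>" not in t for t in trajectories) / n,
+    }
+```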
+
+---
+
+## Hyperparameters
+
+```python
+# Reward function
+F1_THRESHOLD = 0.5       # Threshold for correctness
+MIN_SIGLIP = 0.33        # SigLIP score scaling min
+MAX_SIGLIP = 0.97        # SigLIP score scaling max
+MAX_TURNS = 3            # Maximum retrieval steps
+TIMEOUT_PENALTY = -2.0   # Death penalty for timeout
+
+# Penalty schedule
+PENALTY_MAP = {
+    0: 0.0,
+    1: -0.1,
+    2: -0.25,
+    3: -0.5
+}
+```
diff --git a/ICL/RL/__pycache__/data_utils.cpython-311.pyc b/ICL/RL/__pycache__/data_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..652c86d0b285d59e6e86077bf6fb835de2c98920
Binary files /dev/null and b/ICL/RL/__pycache__/data_utils.cpython-311.pyc differ
diff --git a/ICL/RL/__pycache__/data_utils.cpython-313.pyc b/ICL/RL/__pycache__/data_utils.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..730a1c271a556a93cf465eecff1848ca08bdbb90
Binary files /dev/null and b/ICL/RL/__pycache__/data_utils.cpython-313.pyc differ
diff --git a/ICL/RL/__pycache__/environment.cpython-311.pyc b/ICL/RL/__pycache__/environment.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c1dc5e6407d17eb6e81ea5c18542fc17def7dca
Binary files /dev/null and b/ICL/RL/__pycache__/environment.cpython-311.pyc differ
diff --git a/ICL/RL/__pycache__/environment.cpython-313.pyc b/ICL/RL/__pycache__/environment.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7536c572b42365cb8f02aa565ae2f2e61bf2daa4
Binary files /dev/null and b/ICL/RL/__pycache__/environment.cpython-313.pyc differ
diff --git a/ICL/RL/__pycache__/reward_functions.cpython-311.pyc b/ICL/RL/__pycache__/reward_functions.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ad3a89203b46ec2448e4d0776c96b0fac1e5a9a
Binary files /dev/null and b/ICL/RL/__pycache__/reward_functions.cpython-311.pyc differ
diff --git a/ICL/RL/__pycache__/reward_functions.cpython-313.pyc b/ICL/RL/__pycache__/reward_functions.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de848f41dbc4ad14ed37e578e31731de90790759
Binary files /dev/null and b/ICL/RL/__pycache__/reward_functions.cpython-313.pyc differ
diff --git a/ICL/RL/__pycache__/train_grpo.cpython-313.pyc b/ICL/RL/__pycache__/train_grpo.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93034495e2421e861696d833f2dd11a25e9e4604
Binary files /dev/null and b/ICL/RL/__pycache__/train_grpo.cpython-313.pyc differ
diff --git a/ICL/RL/build_rl_dataset.py b/ICL/RL/build_rl_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c4a2061b1260250aba5ffbd82a6ae5d5065865f
--- /dev/null
+++ b/ICL/RL/build_rl_dataset.py
@@ -0,0 +1,420 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Build RL training dataset with 50:50 Positive/Negative split.
+
+Positive Set (retrieval required):
+  - OK-VQA, A-OKVQA: knowledge-intensive question answering
+  - ScienceQA: science and common-sense reasoning
+
+Negative Set (no/minimal retrieval):
+  - TextVQA, OCR-VQA: reading text from images
+  - VQA-v2: basic visual recognition
+  - CLEVR: pure logical reasoning
+
+Usage:
+    python build_rl_dataset.py \
+        --dataset-root /workspace/xiaobin/M3IT \
+        --output-dir /workspace/xiaobin/RL_data \
+        --positive-samples 10000 \
+        --negative-samples 10000
+"""
+
+import argparse
+import base64
+import hashlib
+import json
+import os
+import random
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, List, Optional
+
+try:
+    from tqdm import tqdm
+except ImportError:
+    tqdm = None
+
+
+_B64_RE = re.compile(r"[^A-Za-z0-9+/=]")
+
+
+@dataclass(frozen=True)
+class QueryItem:
+    """Query item for RL training."""
+    image_path: str
+    question: str
+    answer: str
+    subdir: str
+    uid: str
+    category: str  # "positive" or "negative"
+
+
+def _looks_like_base64(s: str) -> bool:
+    if s.startswith("data:image"):
+        return True
+    if len(s) > 200 and all(c.isalnum() or c in "+/=\n\r" for c in s[:200]):
+        return True
+    return False
+
+
+def _b64_to_image_path(b64: str, cache_dir: Path, prefix: str) -> Optional[str]:
+    s = b64.strip()
+    if s.startswith("data:image"):
+        s = s.split(",", 1)[-1]
+    s = _B64_RE.sub("", s)
+    if not s:
+        return None
+    pad = len(s) % 4
+    if pad:
+        s += "=" * (4 - pad)
+    try:
+        data = base64.b64decode(s)
+    except Exception:
+        return None
+    sha = hashlib.sha1(data).hexdigest()
+    cache_dir.mkdir(parents=True, exist_ok=True)
+    out = cache_dir / f"{prefix}_{sha}.jpg"
+    if not out.exists():
+        with out.open("wb") as f:
+            f.write(data)
+    return str(out)
+
+
+def _resolve_image_path(v: str, dataset_root: Path) -> Optional[str]:
+    p = Path(v)
+    if p.is_absolute():
+        return str(p)
+    cand = dataset_root / v
+    if cand.exists():
+        return str(cand)
+    return str(p)
+
+
+def _extract_image_path(raw: Dict, dataset_root: Path, cache_dir: Path, prefix: str) -> Optional[str]:
+    keys = [
+        "image", "image_path", "image_file", "image_str", "base64",
+        "img", "img_path", "img_str", "image_base64_str", "image_base64",
+        "image_base64s", "images",
+    ]
+    for k in keys:
+        if k not in raw:
+            continue
+        v = raw.get(k)
+        if isinstance(v, list) and v:
+            v = v[0]
+        if not isinstance(v, str) or not v.strip():
+            continue
+        v = v.strip()
+        if _looks_like_base64(v):
+            return _b64_to_image_path(v, cache_dir, prefix)
+        return _resolve_image_path(v, dataset_root)
+    return None
+
+
+def _extract_uid(raw: Dict, fallback: str) -> str:
+    for k in ("id", "image_id", "img_id", "question_id"):
+        v = raw.get(k)
+        if isinstance(v, (str, int)):
+            return str(v)
+    meta = raw.get("meta") if isinstance(raw.get("meta"), dict) else {}
+    for k in ("img_id", "id", "image_id"):
+        v = meta.get(k)
+        if isinstance(v, (str, int)):
+            return str(v)
+    return fallback
+
+
+def _extract_question(raw: Dict) -> Optional[str]:
+    for k in ("text", "question", "query", "prompt", "input"):
+        v = raw.get(k)
+        if isinstance(v, str) and v.strip():
+            return v.strip()
+    return None
+
+
+def _extract_answer(raw: Dict) -> Optional[str]:
+    if "answers" in raw:
+        v = raw.get("answers")
+        if isinstance(v, list):
+            for a in v:
+                if isinstance(a, str) and a.strip():
+                    return a.strip()
+    for k in ("answer", "output", "label", "target", "paraphrased_answer", "original_answer"):
+        v = raw.get(k)
+        if isinstance(v, str) and v.strip():
+            return v.strip()
+    return None
+
+
+def discover_subdirs(dataset_root: Path, categories: List[str]) -> List[str]:
+    """Discover all subdirs under dataset_root/data/{category}."""
+    out = []
+    for cat in 
categories: + base = dataset_root / "data" / cat + if not base.exists(): + continue + for p in sorted(base.iterdir()): + if p.is_dir(): + out.append(f"{cat}/{p.name}") + return out + + +def find_split_file(subdir_dir: Path, split: str) -> Optional[Path]: + """Find jsonl file for given split.""" + if not subdir_dir.exists(): + return None + split = split.lower() + files = sorted(p for p in subdir_dir.iterdir() if p.suffix == ".jsonl") + if not files: + return None + + exact = [p for p in files if split in p.name.lower()] + if exact: + return exact[0] + + if split in ("train", "training"): + for key in ("train", "training"): + cand = [p for p in files if key in p.name.lower()] + if cand: + return cand[0] + + return files[0] + + +def load_instructions(dataset_root: Path, subdir: str) -> List[str]: + """Load instructions for a subdir.""" + base = dataset_root / "data" / subdir + for name in ("instructions.json", "instruction.json"): + path = base / name + if not path.exists(): + continue + with path.open("r", encoding="utf-8") as f: + try: + data = json.load(f) + except Exception: + return [] + if isinstance(data, list): + return [str(x).strip() for x in data if str(x).strip()] + if isinstance(data, dict): + for key in ("instructions", "instruction", "prompts"): + v = data.get(key) + if isinstance(v, list): + return [str(x).strip() for x in v if str(x).strip()] + if isinstance(v, str) and v.strip(): + return [v.strip()] + return [] + return [] + + +def build_query_pool_for_subdir( + dataset_root: Path, + subdir: str, + split: str, + cache_dir: Path, + scan_limit: int, + category: str, + show_progress: bool, +) -> List[QueryItem]: + """Build query pool for a single subdir.""" + subdir_dir = dataset_root / "data" / subdir + jsonl_path = find_split_file(subdir_dir, split) + if jsonl_path is None: + return [] + + out: List[QueryItem] = [] + prefix = subdir.replace("/", "_") + + with jsonl_path.open("r", encoding="utf-8") as f: + lines = [line for line in f if line.strip()] + + if show_progress and tqdm is not None: + lines = tqdm(lines, desc=f"load:{subdir}", unit="rec", mininterval=1.0) + + for idx, line in enumerate(lines): + if scan_limit > 0 and idx >= scan_limit: + break + + try: + raw = json.loads(line) + except Exception: + continue + + question = _extract_question(raw) + answer = _extract_answer(raw) + if not question or not answer: + continue + + img_path = _extract_image_path(raw, dataset_root, cache_dir, prefix) + if not img_path: + continue + + uid = _extract_uid(raw, f"{subdir}:{idx:08d}") + + out.append(QueryItem( + image_path=img_path, + question=question, + answer=answer, + subdir=subdir, + uid=uid, + category=category, + )) + + return out + + +def main() -> int: + ap = argparse.ArgumentParser(description="Build RL training dataset with 50:50 Positive/Negative split") + ap.add_argument("--dataset-root", default="/workspace/xiaobin/M3IT") + ap.add_argument("--output-dir", default="/workspace/xiaobin/RL_data") + ap.add_argument("--split", default="train") + ap.add_argument("--positive-samples", type=int, default=10000) + ap.add_argument("--negative-samples", type=int, default=10000) + ap.add_argument("--scan-limit", type=int, default=100000, help="Max records to scan per subdir") + ap.add_argument("--seed", type=int, default=42) + ap.add_argument("--no-progress", action="store_true") + args = ap.parse_args() + + rng = random.Random(args.seed) + dataset_root = Path(args.dataset_root) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + cache_dir = 
output_dir / "_image_cache" + cache_dir.mkdir(parents=True, exist_ok=True) + show_progress = not args.no_progress + + # Define positive and negative subdirs + positive_subdirs_map = { + "vqa": ["okvqa", "a-okvqa", "science-qa"], + "reasoning": [], # Add if available + } + + negative_subdirs_map = { + "vqa": ["text-vqa", "ocr-vqa", "vqa-v2", "clevr"], + } + + print("[INFO] Discovering subdirs...") + + # Collect positive subdirs + positive_subdirs = [] + for cat, names in positive_subdirs_map.items(): + base = dataset_root / "data" / cat + if not base.exists(): + continue + for name in names: + subdir_path = base / name + if subdir_path.exists() and subdir_path.is_dir(): + positive_subdirs.append(f"{cat}/{name}") + + # Collect negative subdirs + negative_subdirs = [] + for cat, names in negative_subdirs_map.items(): + base = dataset_root / "data" / cat + if not base.exists(): + continue + for name in names: + subdir_path = base / name + if subdir_path.exists() and subdir_path.is_dir(): + negative_subdirs.append(f"{cat}/{name}") + + print(f"[INFO] Found {len(positive_subdirs)} positive subdirs: {positive_subdirs}") + print(f"[INFO] Found {len(negative_subdirs)} negative subdirs: {negative_subdirs}") + + if not positive_subdirs or not negative_subdirs: + print("[ERROR] Need both positive and negative subdirs") + return 1 + + # Build positive pool + print("\n[INFO] Building positive pool...") + positive_pool: List[QueryItem] = [] + for sd in positive_subdirs: + queries = build_query_pool_for_subdir( + dataset_root, sd, args.split, cache_dir, + args.scan_limit, "positive", show_progress + ) + positive_pool.extend(queries) + print(f" [{sd}] Loaded {len(queries)} queries") + + print(f"[INFO] Total positive pool: {len(positive_pool)}") + + # Build negative pool + print("\n[INFO] Building negative pool...") + negative_pool: List[QueryItem] = [] + for sd in negative_subdirs: + queries = build_query_pool_for_subdir( + dataset_root, sd, args.split, cache_dir, + args.scan_limit, "negative", show_progress + ) + negative_pool.extend(queries) + print(f" [{sd}] Loaded {len(queries)} queries") + + print(f"[INFO] Total negative pool: {len(negative_pool)}") + + # Sample + if len(positive_pool) < args.positive_samples: + print(f"[WARN] Only {len(positive_pool)} positive samples available, using all") + positive_samples = positive_pool + else: + positive_samples = rng.sample(positive_pool, args.positive_samples) + + if len(negative_pool) < args.negative_samples: + print(f"[WARN] Only {len(negative_pool)} negative samples available, using all") + negative_samples = negative_pool + else: + negative_samples = rng.sample(negative_pool, args.negative_samples) + + # Combine and shuffle + all_samples = positive_samples + negative_samples + rng.shuffle(all_samples) + + print(f"\n[INFO] Final dataset: {len(positive_samples)} positive + {len(negative_samples)} negative = {len(all_samples)} total") + + # Load instructions for each subdir + inst_map: Dict[str, List[str]] = {} + all_subdirs = set(q.subdir for q in all_samples) + for sd in all_subdirs: + inst_map[sd] = load_instructions(dataset_root, sd) + + # Write to jsonl + output_file = output_dir / "rl_train.jsonl" + with output_file.open("w", encoding="utf-8") as f: + for q in all_samples: + instructions = inst_map.get(q.subdir, []) + instruction = rng.choice(instructions) if instructions else "Please answer the question based on the image." 
+ + rec = { + "id": q.uid, + "category": q.category, + "subdir": q.subdir, + "image": q.image_path, + "question": q.question, + "answer": q.answer, + "instruction": instruction, + } + f.write(json.dumps(rec, ensure_ascii=False) + "\n") + + print(f"\n[INFO] Saved to {output_file}") + + # Write statistics + stats_file = output_dir / "dataset_stats.json" + stats = { + "total_samples": len(all_samples), + "positive_samples": len(positive_samples), + "negative_samples": len(negative_samples), + "positive_subdirs": positive_subdirs, + "negative_subdirs": negative_subdirs, + "positive_distribution": {sd: sum(1 for q in positive_samples if q.subdir == sd) for sd in positive_subdirs}, + "negative_distribution": {sd: sum(1 for q in negative_samples if q.subdir == sd) for sd in negative_subdirs}, + } + with stats_file.open("w", encoding="utf-8") as f: + json.dump(stats, f, indent=2, ensure_ascii=False) + + print(f"[INFO] Saved statistics to {stats_file}") + print("\n[DONE] Dataset construction complete!") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ICL/RL/config.yaml b/ICL/RL/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..730f3287d1737ea5945f18f8d0961948c7c6a639 --- /dev/null +++ b/ICL/RL/config.yaml @@ -0,0 +1,82 @@ +# GRPO Training Configuration + +# Model paths +model: + vlm_path: "/workspace/xiaobin/SFT_model/hf_qwen3vl_siglip_vqa_iter_0000881" + siglip_path: "/workspace/siglip2-so400m-patch16-naflex" + output_dir: "/workspace/xiaobin/RL/output" + +# Training hyperparameters +training: + learning_rate: 1.0e-5 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 4 + num_train_epochs: 3 + warmup_steps: 100 + max_grad_norm: 1.0 + bf16: true + +# GRPO specific +grpo: + num_generations: 8 # Group size for relative optimization + max_prompt_length: 512 + max_completion_length: 1024 + temperature: 0.7 + do_sample: true + +# Environment +environment: + max_turns: 3 # Maximum retrieval turns before timeout + self_exclusion: true # Exclude query image from retrieval + +# Reward function +reward: + # R_outcome + correct_reward: 1.0 + wrong_reward: -1.0 + + # R_penalty (step-based) + penalty_0_shot: 0.0 + penalty_1_shot: -0.1 + penalty_2_shot: -0.25 + penalty_3_shot: -0.5 + penalty_timeout: -2.0 + + # R_rel (SigLIP similarity scaling) + siglip_min_threshold: 0.33 # P05 of same-subdir pairs + siglip_max_threshold: 0.97 # P95 of same-subdir pairs + +# Data +data: + # Positive datasets (needs retrieval) + positive_datasets: + - "okvqa" + - "aokvqa" + - "science_qa" + + # Negative datasets (direct answer) + negative_datasets: + - "text_vqa" + - "ocr_vqa" + - "vqa_v2" + - "clevr" + + # Sampling + max_samples_per_dataset: null # null for full dataset, or set a number for debugging + shuffle: true + seed: 42 + +# Logging +logging: + use_wandb: true + wandb_project: "grpo-retrieval-vqa" + logging_steps: 10 + save_steps: 500 + save_total_limit: 3 + +# Hardware +hardware: + device: "cuda" + num_gpus: 8 + tensor_parallel: 8 + data_parallel: 1 diff --git a/ICL/RL/data_utils.py b/ICL/RL/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8e78b4a07f14c77a200939a55b318d85244c6dad --- /dev/null +++ b/ICL/RL/data_utils.py @@ -0,0 +1,192 @@ +""" +Data utilities for loading and preprocessing M3IT dataset. +Splits into 50% positive (needs retrieval) and 50% negative (direct answer). 
+""" +from datasets import load_dataset, concatenate_datasets +from typing import Dict, List, Tuple +import random + + +# Dataset categorization based on plan +POSITIVE_DATASETS = [ + "okvqa", # Knowledge-intensive VQA + "aokvqa", # A-OKVQA + "science_qa", # Science reasoning + # "viquae", # Entity knowledge (if available) +] + +NEGATIVE_DATASETS = [ + "text_vqa", # Read text from image + "ocr_vqa", # OCR-based VQA + "vqa_v2", # Basic visual recognition + "clevr", # Pure logic reasoning +] + + +def load_m3it_splits( + positive_datasets: List[str] = POSITIVE_DATASETS, + negative_datasets: List[str] = NEGATIVE_DATASETS, + split: str = "train", + max_samples_per_dataset: int = None +) -> Tuple[List[Dict], List[Dict]]: + """ + Load M3IT dataset and split into positive/negative sets. + + Args: + positive_datasets: List of dataset names that need retrieval + negative_datasets: List of dataset names that don't need retrieval + split: Dataset split (train/validation/test) + max_samples_per_dataset: Maximum samples per dataset (for debugging) + + Returns: + (positive_samples, negative_samples) + """ + positive_samples = [] + negative_samples = [] + + print(f"Loading M3IT datasets ({split} split)...") + + # Load positive datasets + for dataset_name in positive_datasets: + try: + ds = load_dataset("MMInstruction/M3IT", dataset_name, split=split) + if max_samples_per_dataset: + ds = ds.select(range(min(len(ds), max_samples_per_dataset))) + + samples = [ + { + "image": sample["image"], + "question": sample["instruction"], + "answer": sample["outputs"], + "dataset": dataset_name, + "needs_retrieval": True, + "id": f"{dataset_name}_{i}" + } + for i, sample in enumerate(ds) + ] + positive_samples.extend(samples) + print(f" ✓ {dataset_name}: {len(samples)} samples (positive)") + except Exception as e: + print(f" ✗ {dataset_name}: Failed to load - {e}") + + # Load negative datasets + for dataset_name in negative_datasets: + try: + ds = load_dataset("MMInstruction/M3IT", dataset_name, split=split) + if max_samples_per_dataset: + ds = ds.select(range(min(len(ds), max_samples_per_dataset))) + + samples = [ + { + "image": sample["image"], + "question": sample["instruction"], + "answer": sample["outputs"], + "dataset": dataset_name, + "needs_retrieval": False, + "id": f"{dataset_name}_{i}" + } + for i, sample in enumerate(ds) + ] + negative_samples.extend(samples) + print(f" ✓ {dataset_name}: {len(samples)} samples (negative)") + except Exception as e: + print(f" ✗ {dataset_name}: Failed to load - {e}") + + print(f"\nTotal: {len(positive_samples)} positive, {len(negative_samples)} negative") + return positive_samples, negative_samples + + +def create_balanced_dataset( + positive_samples: List[Dict], + negative_samples: List[Dict], + shuffle: bool = True, + seed: int = 42 +) -> List[Dict]: + """ + Create balanced 50:50 dataset by sampling. 
+
+    Args:
+        positive_samples: Positive samples (needs retrieval)
+        negative_samples: Negative samples (direct answer)
+        shuffle: Whether to shuffle the combined dataset
+        seed: Random seed
+
+    Returns:
+        Balanced dataset
+    """
+    # Balance to 50:50
+    min_size = min(len(positive_samples), len(negative_samples))
+
+    random.seed(seed)
+    balanced_positive = random.sample(positive_samples, min_size)
+    balanced_negative = random.sample(negative_samples, min_size)
+
+    # Combine
+    balanced_dataset = balanced_positive + balanced_negative
+
+    if shuffle:
+        random.shuffle(balanced_dataset)
+
+    print(f"Created balanced dataset: {len(balanced_dataset)} samples (50% pos, 50% neg)")
+    return balanced_dataset
+
+
+def build_candidate_pools(
+    positive_samples: List[Dict],
+    negative_samples: List[Dict]
+) -> Dict[str, List[Dict]]:
+    """
+    Build candidate pools for retrieval, organized by dataset.
+    Each pool contains all images from that dataset for sub-dataset retrieval.
+
+    Args:
+        positive_samples: Positive samples
+        negative_samples: Negative samples
+
+    Returns:
+        Dictionary mapping dataset name to list of candidates
+    """
+    all_samples = positive_samples + negative_samples
+    pools = {}
+
+    for sample in all_samples:
+        dataset_name = sample["dataset"]
+        if dataset_name not in pools:
+            pools[dataset_name] = []
+
+        pools[dataset_name].append({
+            "image": sample["image"],
+            "caption": sample["answer"],  # Use answer as caption
+            "id": sample["id"]
+        })
+
+    print(f"\nBuilt candidate pools for {len(pools)} datasets:")
+    for name, pool in pools.items():
+        print(f"  {name}: {len(pool)} candidates")
+
+    return pools
+
+
+def format_prompt(sample: Dict, include_system_prompt: bool = True) -> str:
+    """
+    Format a sample into a prompt for the model.
+
+    Args:
+        sample: Sample dictionary
+        include_system_prompt: Whether to include system instructions
+
+    Returns:
+        Formatted prompt string
+    """
+    system_prompt = ""
+    if include_system_prompt:
+        system_prompt = (
+            "You are a visual question answering assistant. "
+            "You can either answer directly using <answer> or retrieve similar images using <retrieval>.\n"
+            "- Use <retrieval> followed by a description when you need to see similar examples.\n"
+            "- Use <answer> followed by your answer when you're ready to answer.\n\n"
+        )
+
+    user_prompt = f"Question: {sample['question']}\n"
+
+    return system_prompt + user_prompt
diff --git a/ICL/RL/environment.py b/ICL/RL/environment.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7cc033aa9c4545ff957e4cc56ba73b7da32baad
--- /dev/null
+++ b/ICL/RL/environment.py
@@ -0,0 +1,240 @@
+"""
+Environment for GRPO rollout with SigLIP-based retrieval.
+Implements the while-loop interaction: <retrieval> -> retrieve top-1 -> inject -> continue
+"""
+import torch
+import torch.nn.functional as F
+from typing import List, Dict, Any, Tuple, Optional
+from PIL import Image
+import numpy as np
+
+
+class RetrievalEnvironment:
+    """
+    Environment that handles the retrieval loop during GRPO rollout.
+ """ + + def __init__( + self, + siglip_model, + siglip_processor, + candidate_pools: Dict[str, List[Dict]], + max_turns: int = 3, + device: str = "cuda" + ): + """ + Args: + siglip_model: SigLIP vision-language model + siglip_processor: SigLIP processor + candidate_pools: Dict mapping dataset name to list of candidate images + Each candidate: {"image": PIL.Image, "caption": str, "id": str} + max_turns: Maximum retrieval turns before timeout + device: Device for computation + """ + self.siglip_model = siglip_model + self.siglip_processor = siglip_processor + self.candidate_pools = candidate_pools + self.max_turns = max_turns + self.device = device + + # Precompute image embeddings for each pool + self.pool_embeddings = {} + self._precompute_embeddings() + + def _precompute_embeddings(self): + """Precompute SigLIP embeddings for all candidate images.""" + print("Precomputing SigLIP embeddings for candidate pools...") + self.siglip_model.eval() + + # Get the device of the first parameter of siglip_model + model_device = next(self.siglip_model.parameters()).device + + for dataset_name, candidates in self.candidate_pools.items(): + images = [c["image"] for c in candidates] + embeddings = [] + + # Process in batches + batch_size = 32 + for i in range(0, len(images), batch_size): + batch_images = images[i:i + batch_size] + inputs = self.siglip_processor( + images=batch_images, + return_tensors="pt", + padding=True + ) + + # Move inputs to the same device as siglip_model + inputs = {k: v.to(model_device) for k, v in inputs.items()} + + with torch.no_grad(): + image_embeds = self.siglip_model.get_image_features(**inputs) + # Extract tensor from output object if needed + if hasattr(image_embeds, 'pooler_output'): + image_embeds = image_embeds.pooler_output + elif not isinstance(image_embeds, torch.Tensor): + image_embeds = image_embeds[0] + image_embeds = F.normalize(image_embeds, p=2, dim=-1) + embeddings.append(image_embeds.cpu()) + + self.pool_embeddings[dataset_name] = torch.cat(embeddings, dim=0) + print(f" {dataset_name}: {len(candidates)} images") + + def retrieve_top1( + self, + query_text: str, + dataset_name: str, + exclude_id: str + ) -> Tuple[Dict, float]: + """ + Retrieve top-1 most similar image based on query text. 
+ + Args: + query_text: Description text from model's output + dataset_name: Which candidate pool to search + exclude_id: Image ID to exclude (the query image itself) + + Returns: + (retrieved_candidate, similarity_score) + """ + # Encode query text + text_inputs = self.siglip_processor( + text=[query_text], + return_tensors="pt", + padding=True + ) + + # Move inputs to the same device as siglip_model + # Get the device of the first parameter of siglip_model + model_device = next(self.siglip_model.parameters()).device + text_inputs = {k: v.to(model_device) for k, v in text_inputs.items()} + + with torch.no_grad(): + text_embeds = self.siglip_model.get_text_features(**text_inputs) + # Extract tensor from output object if needed + if hasattr(text_embeds, 'pooler_output'): + text_embeds = text_embeds.pooler_output + elif not isinstance(text_embeds, torch.Tensor): + text_embeds = text_embeds[0] + text_embeds = F.normalize(text_embeds, p=2, dim=-1) + + # Compute similarities - move image_embeds to same device as text_embeds + image_embeds = self.pool_embeddings[dataset_name].to(text_embeds.device) + similarities = (text_embeds @ image_embeds.T).squeeze(0) + + # Get candidates + candidates = self.candidate_pools[dataset_name] + + # Exclude query image + valid_indices = [i for i, c in enumerate(candidates) if c["id"] != exclude_id] + valid_similarities = similarities[valid_indices] + + # Get top-1 + top_idx = valid_similarities.argmax().item() + actual_idx = valid_indices[top_idx] + top_score = valid_similarities[top_idx].item() + + return candidates[actual_idx], top_score + + def compute_image_similarity( + self, + query_image: Image.Image, + retrieved_image: Image.Image + ) -> float: + """ + Compute SigLIP image-to-image similarity for R_rel reward. + """ + inputs = self.siglip_processor( + images=[query_image, retrieved_image], + return_tensors="pt", + padding=True + ) + + # Move inputs to the same device as siglip_model + model_device = next(self.siglip_model.parameters()).device + inputs = {k: v.to(model_device) for k, v in inputs.items()} + + with torch.no_grad(): + image_embeds = self.siglip_model.get_image_features(**inputs) + # Extract tensor from output object if needed + if hasattr(image_embeds, 'pooler_output'): + image_embeds = image_embeds.pooler_output + elif not isinstance(image_embeds, torch.Tensor): + image_embeds = image_embeds[0] + image_embeds = F.normalize(image_embeds, p=2, dim=-1) + + similarity = (image_embeds[0] @ image_embeds[1]).item() + return similarity + + def rollout( + self, + model, + tokenizer, + initial_prompt: str, + query_image: Image.Image, + query_id: str, + dataset_name: str, + generation_kwargs: Dict[str, Any] + ) -> Tuple[str, List[float], bool]: + """ + Execute one rollout trajectory with retrieval loop. 
+
+        Args:
+            model: VLM model
+            tokenizer: Tokenizer
+            initial_prompt: Initial prompt with image and question
+            query_image: Query image
+            query_id: Query image ID for self-exclusion
+            dataset_name: Dataset name for candidate pool
+            generation_kwargs: Generation parameters
+
+        Returns:
+            (full_trajectory, siglip_scores, timeout)
+        """
+        conversation_history = initial_prompt
+        siglip_scores = []
+        timeout = False
+
+        for turn in range(self.max_turns):
+            # Generate model output
+            output = model.generate(
+                conversation_history,
+                **generation_kwargs
+            )
+
+            # Check if model outputs <answer>
+            if "<answer>" in output:
+                conversation_history += output
+                break
+
+            # Check if model outputs <retrieval>
+            if "<retrieval>" in output:
+                # Extract description text after <retrieval>
+                description = output.split("<retrieval>")[-1].strip()
+
+                # Retrieve top-1 similar image
+                retrieved_candidate, _ = self.retrieve_top1(
+                    query_text=description,
+                    dataset_name=dataset_name,
+                    exclude_id=query_id
+                )
+
+                # Compute image-to-image similarity for reward
+                img_similarity = self.compute_image_similarity(
+                    query_image,
+                    retrieved_candidate["image"]
+                )
+                siglip_scores.append(img_similarity)
+
+                # Inject retrieved image into conversation
+                system_response = f"\nSystem: Retrieved Image - {retrieved_candidate['caption']}\n"
+                conversation_history += output + system_response
+            else:
+                # Model output neither <answer> nor <retrieval>, treat as direct answer
+                conversation_history += output
+                break
+
+        # Check timeout
+        if turn == self.max_turns - 1 and "<answer>" not in conversation_history:
+            timeout = True
+
+        return conversation_history, siglip_scores, timeout
diff --git a/ICL/RL/inference_example.py b/ICL/RL/inference_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..e65dec72b243c9fb765c1dc8c466bceb331ebd9b
--- /dev/null
+++ b/ICL/RL/inference_example.py
@@ -0,0 +1,157 @@
+"""
+Example inference script for the trained GRPO model.
+Shows how to use the model for retrieval-augmented VQA.
+"""
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
+from PIL import Image
+from environment import RetrievalEnvironment
+
+
+def load_trained_model(model_path: str, device: str = "cuda"):
+    """Load the trained GRPO model."""
+    print(f"Loading model from {model_path}...")
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+        trust_remote_code=True
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    return model, tokenizer
+
+
+def inference_with_retrieval(
+    model,
+    tokenizer,
+    environment: RetrievalEnvironment,
+    image_path: str,
+    question: str,
+    dataset_name: str,
+    image_id: str,
+    max_new_tokens: int = 512
+):
+    """
+    Run inference with retrieval capability.
+
+    Args:
+        model: Trained VLM model
+        tokenizer: Tokenizer
+        environment: Retrieval environment
+        image_path: Path to query image
+        question: Question to answer
+        dataset_name: Dataset name for candidate pool
+        image_id: Image ID for self-exclusion
+        max_new_tokens: Maximum tokens to generate
+
+    Returns:
+        Final answer and retrieval history
+    """
+    # Load image
+    image = Image.open(image_path).convert("RGB")
+
+    # Format prompt
+    prompt = (
+        "You are a visual question answering assistant. 
" + "You can either answer directly using or retrieve similar images using .\n" + f"Question: {question}\n" + ) + + # Execute rollout + trajectory, siglip_scores, timeout = environment.rollout( + model=model, + tokenizer=tokenizer, + initial_prompt=prompt, + query_image=image, + query_id=image_id, + dataset_name=dataset_name, + generation_kwargs={ + "max_new_tokens": max_new_tokens, + "temperature": 0.7, + "do_sample": True, + } + ) + + # Extract final answer + if "" in trajectory: + answer = trajectory.split("")[-1].strip() + else: + answer = "No answer generated (timeout)" + + # Count retrieval steps + num_retrievals = trajectory.count("") + + return { + "answer": answer, + "num_retrievals": num_retrievals, + "siglip_scores": siglip_scores, + "timeout": timeout, + "full_trajectory": trajectory + } + + +def main(): + """Example usage.""" + # Configuration + MODEL_PATH = "/workspace/xiaobin/RL/output/final_model" + SIGLIP_MODEL = "/workspace/siglip2-so400m-patch16-naflex" + + # Example query + IMAGE_PATH = "path/to/your/image.jpg" + QUESTION = "What is the capital of the country shown in this image?" + DATASET_NAME = "okvqa" # Which candidate pool to use + IMAGE_ID = "example_001" + + print("="*80) + print("GRPO MODEL INFERENCE EXAMPLE") + print("="*80) + print() + + # Load model + model, tokenizer = load_trained_model(MODEL_PATH) + + # Load SigLIP for retrieval + print(f"Loading SigLIP from {SIGLIP_MODEL}...") + from transformers import AutoModel + siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL).to("cuda") + siglip_processor = AutoProcessor.from_pretrained(SIGLIP_MODEL) + + # Initialize environment (you need to provide candidate pools) + # For this example, we'll skip the full environment setup + print("\nNote: Full environment setup requires candidate pools.") + print("See train_grpo.py for complete example.") + print() + + # Example without environment (direct generation) + print("Running direct generation (without retrieval)...") + print(f"Question: {question}") + + # Format prompt + prompt = ( + "You are a visual question answering assistant. 
" + "Answer the following question about the image.\n" + f"Question: {QUESTION}\n" + "Answer:" + ) + + # Generate + inputs = tokenizer(prompt, return_tensors="pt").to("cuda") + outputs = model.generate( + **inputs, + max_new_tokens=128, + temperature=0.7, + do_sample=True + ) + answer = tokenizer.decode(outputs[0], skip_special_tokens=True) + + print(f"Answer: {answer}") + print() + + print("="*80) + print("For full retrieval-augmented inference, use the RetrievalEnvironment") + print("with properly initialized candidate pools.") + print("="*80) + + +if __name__ == "__main__": + main() diff --git a/ICL/RL/key_metrics_20260220_152053.log b/ICL/RL/key_metrics_20260220_152053.log new file mode 100644 index 0000000000000000000000000000000000000000..f75e9907e628cd26d95e6478a340e1d0b62f6a94 --- /dev/null +++ b/ICL/RL/key_metrics_20260220_152053.log @@ -0,0 +1,10 @@ + 0%| | 1/1875 [00:32<16:39:52, 32.01s/it] 0%| | 2/1875 [01:16<20:27:46, 39.33s/it] 0%| | 3/1875 [02:00<21:34:29, 41.49s/it] 0%| | 4/1875 [02:22<17:39:25, 33.97s/it] 0%| | 5/1875 [03:04<19:02:40, 36.66s/it] 0%| | 6/1875 [03:47<20:08:12, 38.79s/it] 0%| | 7/1875 [04:24<19:52:10, 38.29s/it] 0%| | 8/1875 [05:06<20:28:55, 39.49s/it] 0%| | 9/1875 [05:53<21:37:22, 41.72s/it] 1%| | 10/1875 [06:23<19:46:45, 38.18s/it] 1%| | 10/1875 [06:23<19:46:45, 38.18s/it] 1%| | 11/1875 [07:03<20:03:40, 38.74s/it] 1%| | 12/1875 [07:33<18:38:14, 36.01s/it] 1%| | 13/1875 [07:57<16:42:30, 32.30s/it] 1%| | 14/1875 [08:24<15:53:26, 30.74s/it] 1%| | 15/1875 [08:54<15:50:11, 30.65s/it] 1%| | 16/1875 [09:24<15:44:34, 30.49s/it] 1%| | 17/1875 [09:51<15:08:09, 29.33s/it] 1%| | 18/1875 [10:12<13:54:45, 26.97s/it] 1%| | 19/1875 [10:52<15:49:15, 30.69s/it] 1%| | 20/1875 [11:20<15:22:48, 29.85s/it] 1%| | 20/1875 [11:20<15:22:48, 29.85s/it] 1%| | 21/1875 [11:44<14:27:00, 28.06s/it] 1%| | 22/1875 [12:09<14:04:22, 27.34s/it] 1%| | 23/1875 [12:52<16:26:19, 31.95s/it] 1%|▏ | 24/1875 [13:14<14:54:48, 29.00s/it] 1%|▏ | 25/1875 [13:59<17:22:59, 33.83s/it] 1%|▏ | 26/1875 [14:30<16:58:17, 33.04s/it] 1%|▏ | 27/1875 [15:10<17:55:23, 34.92s/it] 1%|▏ | 28/1875 [15:40<17:15:56, 33.65s/it] 2%|▏ | 29/1875 [16:12<16:53:40, 32.95s/it] 2%|▏ | 30/1875 [16:44<16:44:30, 32.67s/it] 2%|▏ | 30/1875 [16:44<16:44:30, 32.67s/it] 2%|▏ | 31/1875 [17:29<18:37:07, 36.35s/it] 2%|▏ | 32/1875 [17:54<16:52:23, 32.96s/it] 2%|▏ | 33/1875 [18:26<16:46:40, 32.79s/it] 2%|▏ | 34/1875 [19:03<17:27:09, 34.13s/it] 2%|▏ | 35/1875 [19:48<19:07:40, 37.42s/it] 2%|▏ | 36/1875 [20:12<16:57:07, 33.19s/it] 2%|▏ | 37/1875 [20:54<18:21:52, 35.97s/it] 2%|▏ | 38/1875 [21:34<18:53:11, 37.01s/it] 2%|▏ | 39/1875 [22:07<18:19:29, 35.93s/it] 2%|▏ | 40/1875 [22:47<18:52:31, 37.03s/it] 2%|▏ | 40/1875 [22:47<18:52:31, 37.03s/it] 2%|▏ | 41/1875 [23:26<19:14:46, 37.78s/it] 2%|▏ | 42/1875 [23:58<18:17:29, 35.92s/it] 2%|▏ | 43/1875 [24:40<19:17:58, 37.92s/it] 2%|▏ | 44/1875 [25:20<19:31:01, 38.37s/it] 2%|▏ | 45/1875 [26:05<20:38:02, 40.59s/it] 2%|▏ | 46/1875 [26:52<21:28:49, 42.28s/it] 3%|▎ | 47/1875 [27:39<22:11:46, 43.71s/it] 3%|▎ | 48/1875 [28:22<22:11:17, 43.72s/it] 3%|▎ | 49/1875 [28:51<19:49:53, 39.10s/it] 3%|▎ | 50/1875 [29:30<19:53:24, 39.24s/it] 3%|▎ | 50/1875 [29:30<19:53:24, 39.24s/it] 3%|▎ | 51/1875 [30:09<19:48:00, 39.08s/it] 3%|▎ | 52/1875 [30:54<20:37:08, 40.72s/it] 3%|▎ | 53/1875 [31:33<20:20:26, 40.19s/it] 3%|▎ | 54/1875 [32:13<20:24:26, 40.34s/it] 3%|▎ | 55/1875 [32:41<18:25:36, 36.45s/it] 3%|▎ | 56/1875 [33:05<16:36:53, 32.88s/it] 3%|▎ | 57/1875 [33:42<17:14:54, 34.16s/it] 3%|▎ | 58/1875 [34:22<18:03:37, 35.78s/it] 3%|▎ 
| 59/1875 [35:04<18:59:14, 37.64s/it] 3%|▎ | 60/1875 [35:31<17:23:41, 34.50s/it] 3%|▎ | 60/1875 [35:31<17:23:41, 34.50s/it] 3%|▎ | 61/1875 [35:57<16:08:17, 32.03s/it] 3%|▎ | 62/1875 [36:40<17:48:49, 35.37s/it] 3%|▎ | 63/1875 [37:19<18:19:46, 36.42s/it] 3%|▎ | 64/1875 [38:00<19:01:09, 37.81s/it] 3%|▎ | 65/1875 [38:41<19:26:56, 38.68s/it] 4%|▎ | 66/1875 [39:21<19:37:09, 39.04s/it] 4%|▎ | 67/1875 [40:00<19:38:11, 39.10s/it] 4%|▎ | 68/1875 [40:44<20:17:55, 40.44s/it] 4%|▎ | 69/1875 [41:25<20:23:31, 40.65s/it] 4%|▎ | 70/1875 [41:57<19:01:00, 37.93s/it] 4%|▎ | 70/1875 [41:57<19:01:00, 37.93s/it] 4%|▍ | 71/1875 [42:37<19:23:13, 38.69s/it] 4%|▍ | 72/1875 [43:03<17:31:05, 34.98s/it] 4%|▍ | 73/1875 [43:26<15:40:11, 31.30s/it] 4%|▍ | 74/1875 [43:50<14:36:17, 29.19s/it] 4%|▍ | 75/1875 [44:18<14:21:09, 28.71s/it] 4%|▍ | 76/1875 [44:59<16:11:41, 32.41s/it] 4%|▍ | 77/1875 [45:38<17:11:01, 34.41s/it] 4%|▍ | 78/1875 [46:20<18:18:18, 36.67s/it] 4%|▍ | 79/1875 [46:54<17:52:53, 35.84s/it] 4%|▍ | 80/1875 [47:34<18:28:54, 37.07s/it] 4%|▍ | 80/1875 [47:34<18:28:54, 37.07s/it] 4%|▍ | 81/1875 [47:58<16:31:27, 33.16s/it] 4%|▍ | 82/1875 [48:42<18:12:46, 36.57s/it] 4%|▍ | 83/1875 [49:19<18:17:12, 36.74s/it] 4%|▍ | 84/1875 [49:45<16:37:36, 33.42s/it] 5%|▍ | 85/1875 [50:12<15:40:58, 31.54s/it] 5%|▍ | 86/1875 [50:57<17:36:00, 35.42s/it] 5%|▍ | 87/1875 [51:24<16:22:18, 32.96s/it] 5%|▍ | 88/1875 [52:03<17:18:33, 34.87s/it] 5%|▍ | 89/1875 [52:28<15:46:47, 31.81s/it] 5%|▍ | 90/1875 [52:55<15:07:35, 30.51s/it] 5%|▍ | 90/1875 [52:55<15:07:35, 30.51s/it] 5%|▍ | 91/1875 [53:24<14:48:09, 29.87s/it] 5%|▍ | 92/1875 [53:50<14:16:30, 28.82s/it] 5%|▍ | 93/1875 [54:16<13:50:31, 27.96s/it] 5%|▌ | 94/1875 [54:44<13:46:50, 27.86s/it] 5%|▌ | 95/1875 [55:18<14:44:24, 29.81s/it] 5%|▌ | 96/1875 [55:42<13:49:10, 27.97s/it] 5%|▌ | 97/1875 [56:27<16:18:41, 33.03s/it] 5%|▌ | 98/1875 [56:51<15:02:24, 30.47s/it] 5%|▌ | 99/1875 [57:35<17:03:00, 34.56s/it] 5%|▌ | 100/1875 [58:10<17:08:10, 34.76s/it] 5%|▌ | 100/1875 [58:10<17:08:10, 34.76s/it]{'loss': '-0.05256', 'grad_norm': '5.906', 'learning_rate': '9e-07', 'num_tokens': '3.419e+05', 'completions/mean_length': '28.2', 'completions/min_length': '7.1', 'completions/max_length': '203.4', 'completions/clipped_ratio': '0.01094', 'completions/mean_terminated_length': '25.66', 'completions/min_terminated_length': '7.1', 'completions/max_terminated_length': '96.7', 'rewards/reward_function/mean': '-0.6679', 'rewards/reward_function/std': '0.5618', 'reward': '-0.6679', 'reward_std': '0.5618', 'frac_reward_zero_std': '0.01875', 'kl': '0.000573', 'entropy': '1.521', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '38.24', 'epoch': '0.016'} +{'loss': '0.006149', 'grad_norm': '7.188', 'learning_rate': '1.9e-06', 'num_tokens': '6.994e+05', 'completions/mean_length': '25.21', 'completions/min_length': '7', 'completions/max_length': '97.2', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '25.21', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '97.2', 'rewards/reward_function/mean': '-0.6674', 'rewards/reward_function/std': '0.5884', 'reward': '-0.6674', 'reward_std': '0.5884', 'frac_reward_zero_std': '0.025', 'kl': '0.0006873', 'entropy': '0.5141', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '29.57', 'epoch': '0.032'} +{'loss': '0.00594', 
'grad_norm': '5.969', 'learning_rate': '2.9e-06', 'num_tokens': '1.024e+06', 'completions/mean_length': '29.12', 'completions/min_length': '7.5', 'completions/max_length': '143', 'completions/clipped_ratio': '0.004687', 'completions/mean_terminated_length': '28.04', 'completions/min_terminated_length': '7.5', 'completions/max_terminated_length': '100.2', 'rewards/reward_function/mean': '-0.6755', 'rewards/reward_function/std': '0.5826', 'reward': '-0.6755', 'reward_std': '0.5826', 'frac_reward_zero_std': '0.03125', 'kl': '0.002043', 'entropy': '0.8444', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '32.29', 'epoch': '0.048'} +{'loss': '-0.02556', 'grad_norm': '7.344', 'learning_rate': '3.9e-06', 'num_tokens': '1.341e+06', 'completions/mean_length': '32.32', 'completions/min_length': '7.6', 'completions/max_length': '191', 'completions/clipped_ratio': '0.007812', 'completions/mean_terminated_length': '30.55', 'completions/min_terminated_length': '7.6', 'completions/max_terminated_length': '129.5', 'rewards/reward_function/mean': '-0.6495', 'rewards/reward_function/std': '0.6234', 'reward': '-0.6495', 'reward_std': '0.6234', 'frac_reward_zero_std': '0.05625', 'kl': '0.006311', 'entropy': '1.191', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '36.22', 'epoch': '0.064'} +{'loss': '-0.01995', 'grad_norm': '5.094', 'learning_rate': '4.9e-06', 'num_tokens': '1.695e+06', 'completions/mean_length': '35.92', 'completions/min_length': '7.3', 'completions/max_length': '226.1', 'completions/clipped_ratio': '0.02031', 'completions/mean_terminated_length': '31.4', 'completions/min_terminated_length': '7.3', 'completions/max_terminated_length': '103.7', 'rewards/reward_function/mean': '-0.5803', 'rewards/reward_function/std': '0.6009', 'reward': '-0.5803', 'reward_std': '0.6009', 'frac_reward_zero_std': '0.0875', 'kl': '0.02276', 'entropy': '1.659', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '40.29', 'epoch': '0.08'} +{'loss': '0.08347', 'grad_norm': '7.656', 'learning_rate': '5e-06', 'num_tokens': '2.039e+06', 'completions/mean_length': '39.98', 'completions/min_length': '9.5', 'completions/max_length': '208.9', 'completions/clipped_ratio': '0.02969', 'completions/mean_terminated_length': '33.43', 'completions/min_terminated_length': '9.5', 'completions/max_terminated_length': '122.4', 'rewards/reward_function/mean': '-0.3908', 'rewards/reward_function/std': '0.6835', 'reward': '-0.3908', 'reward_std': '0.6835', 'frac_reward_zero_std': '0.2062', 'kl': '0.07372', 'entropy': '1.91', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '35.99', 'epoch': '0.096'} +{'loss': '0.1118', 'grad_norm': '4.375', 'learning_rate': '5e-06', 'num_tokens': '2.398e+06', 'completions/mean_length': '37.39', 'completions/min_length': '8.1', 'completions/max_length': '222.2', 'completions/clipped_ratio': '0.02031', 'completions/mean_terminated_length': '32.87', 'completions/min_terminated_length': '8.1', 'completions/max_terminated_length': '113.3', 'rewards/reward_function/mean': '-0.3459', 'rewards/reward_function/std': '0.7209', 'reward': '-0.3459', 'reward_std': '0.7209', 
'frac_reward_zero_std': '0.1812', 'kl': '0.05823', 'entropy': '1.797', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '38.46', 'epoch': '0.112'}
+{'loss': '0.08823', 'grad_norm': '3.328', 'learning_rate': '5e-06', 'num_tokens': '2.736e+06', 'completions/mean_length': '33.79', 'completions/min_length': '9.6', 'completions/max_length': '181.3', 'completions/clipped_ratio': '0.01094', 'completions/mean_terminated_length': '31.35', 'completions/min_terminated_length': '9.6', 'completions/max_terminated_length': '113.7', 'rewards/reward_function/mean': '-0.2964', 'rewards/reward_function/std': '0.7092', 'reward': '-0.2964', 'reward_std': '0.7092', 'frac_reward_zero_std': '0.2562', 'kl': '0.07571', 'entropy': '1.331', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '33.64', 'epoch': '0.128'}
+{'loss': '0.0669', 'grad_norm': '3.906', 'learning_rate': '5e-06', 'num_tokens': '3.073e+06', 'completions/mean_length': '30.91', 'completions/min_length': '9.5', 'completions/max_length': '156.7', 'completions/clipped_ratio': '0.009375', 'completions/mean_terminated_length': '28.78', 'completions/min_terminated_length': '9.5', 'completions/max_terminated_length': '95.5', 'rewards/reward_function/mean': '-0.2708', 'rewards/reward_function/std': '0.7212', 'reward': '-0.2708', 'reward_std': '0.7212', 'frac_reward_zero_std': '0.3125', 'kl': '0.06573', 'entropy': '1.085', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '32.09', 'epoch': '0.144'}
+{'loss': '0.04796', 'grad_norm': '8.562', 'learning_rate': '5e-06', 'num_tokens': '3.432e+06', 'completions/mean_length': '27.49', 'completions/min_length': '9.2', 'completions/max_length': '127', 'completions/clipped_ratio': '0.003125', 'completions/mean_terminated_length': '26.78', 'completions/min_terminated_length': '9.2', 'completions/max_terminated_length': '90.9', 'rewards/reward_function/mean': '-0.2074', 'rewards/reward_function/std': '0.7264', 'reward': '-0.2074', 'reward_std': '0.7264', 'frac_reward_zero_std': '0.325', 'kl': '0.06459', 'entropy': '0.6255', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '31.41', 'epoch': '0.16'}
diff --git a/ICL/RL/key_metrics_20260224_094601.log b/ICL/RL/key_metrics_20260224_094601.log
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ICL/RL/key_metrics_20260224_133510.log b/ICL/RL/key_metrics_20260224_133510.log
new file mode 100644
index 0000000000000000000000000000000000000000..1a0aef69a42029f45745e10c01f659ea10a5dda9
--- /dev/null
+++ b/ICL/RL/key_metrics_20260224_133510.log
@@ -0,0 +1,188 @@
+ [tqdm progress-bar output stripped: steps 1-100 of 1875, roughly 30-42 s/it] {'loss': '-0.03673', 'grad_norm': '5.344', 'learning_rate': '9e-07', 'num_tokens': '3.42e+05', 'completions/mean_length': '28.24', 'completions/min_length': '7.1', 'completions/max_length': '203.4', 'completions/clipped_ratio': '0.01094', 'completions/mean_terminated_length': '25.71', 'completions/min_terminated_length': '7.1', 'completions/max_terminated_length': '97.1', 'rewards/reward_function/mean': '-0.673', 'rewards/reward_function/std': '0.5587', 'reward': '-0.673', 'reward_std': '0.5587', 'frac_reward_zero_std': '0.01875', 'kl': '0.0005538', 'entropy': '1.539', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '38.03', 'epoch': '0.016'}
+{'loss': '-0.03881', 'grad_norm': '4.562', 'learning_rate': '1.9e-06', 'num_tokens': '7.007e+05', 'completions/mean_length': '27.24', 'completions/min_length': '7', 'completions/max_length': '142.6', 'completions/clipped_ratio': '0.003125', 'completions/mean_terminated_length': '26.53', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '107', 'rewards/reward_function/mean': '-0.7122', 'rewards/reward_function/std': '0.5238', 'reward': '-0.7122', 'reward_std': '0.5238', 'frac_reward_zero_std': '0.04375', 'kl': '0.0006692', 'entropy': '0.8416', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '33.75', 'epoch': '0.032'}
+{'loss': '0.03503', 'grad_norm': '6.938', 'learning_rate': '2.9e-06', 'num_tokens': '1.026e+06', 'completions/mean_length': '30.58', 'completions/min_length': '7.2', 'completions/max_length': '181', 'completions/clipped_ratio': '0.00625', 'completions/mean_terminated_length': '29.17', 'completions/min_terminated_length': '7.2', 'completions/max_terminated_length': '135.1', 'rewards/reward_function/mean': '-0.6582', 'rewards/reward_function/std': '0.607', 'reward': '-0.6582', 'reward_std': '0.607', 'frac_reward_zero_std': '0.0375', 'kl': '0.001262', 'entropy': '1.109', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '35.76', 'epoch': '0.048'}
+{'loss': '-0.04466', 'grad_norm': '4.219', 'learning_rate': '3.9e-06', 'num_tokens': '1.344e+06', 'completions/mean_length': '33.78', 'completions/min_length': '7.6', 'completions/max_length': '184.2', 'completions/clipped_ratio': '0.01094', 'completions/mean_terminated_length': '31.36', 'completions/min_terminated_length': '7.6', 'completions/max_terminated_length': '125.5', 'rewards/reward_function/mean': '-0.6545', 'rewards/reward_function/std': '0.5966', 'reward': '-0.6545', 'reward_std': '0.5966', 'frac_reward_zero_std': '0.05625', 'kl': '0.003768', 'entropy': '1.312', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0',
'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '35.32', 'epoch': '0.064'} +{'loss': '-0.01845', 'grad_norm': '4.281', 'learning_rate': '4.9e-06', 'num_tokens': '1.699e+06', 'completions/mean_length': '35.9', 'completions/min_length': '7.2', 'completions/max_length': '214.8', 'completions/clipped_ratio': '0.01719', 'completions/mean_terminated_length': '32.1', 'completions/min_terminated_length': '7.2', 'completions/max_terminated_length': '122.7', 'rewards/reward_function/mean': '-0.6193', 'rewards/reward_function/std': '0.5866', 'reward': '-0.6193', 'reward_std': '0.5866', 'frac_reward_zero_std': '0.09375', 'kl': '0.02321', 'entropy': '1.699', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '39.27', 'epoch': '0.08'} +{'loss': '0.03927', 'grad_norm': '6.688', 'learning_rate': '5e-06', 'num_tokens': '2.044e+06', 'completions/mean_length': '42.26', 'completions/min_length': '8.2', 'completions/max_length': '233', 'completions/clipped_ratio': '0.03438', 'completions/mean_terminated_length': '34.74', 'completions/min_terminated_length': '8.2', 'completions/max_terminated_length': '135.3', 'rewards/reward_function/mean': '-0.4009', 'rewards/reward_function/std': '0.6767', 'reward': '-0.4009', 'reward_std': '0.6767', 'frac_reward_zero_std': '0.175', 'kl': '0.06153', 'entropy': '2.341', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '38.07', 'epoch': '0.096'} +{'loss': '0.1242', 'grad_norm': '7.719', 'learning_rate': '5e-06', 'num_tokens': '2.403e+06', 'completions/mean_length': '38.81', 'completions/min_length': '8.8', 'completions/max_length': '222.6', 'completions/clipped_ratio': '0.02656', 'completions/mean_terminated_length': '32.89', 'completions/min_terminated_length': '8.8', 'completions/max_terminated_length': '112.7', 'rewards/reward_function/mean': '-0.3243', 'rewards/reward_function/std': '0.725', 'reward': '-0.3243', 'reward_std': '0.725', 'frac_reward_zero_std': '0.15', 'kl': '0.07755', 'entropy': '1.927', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '38.17', 'epoch': '0.112'} +{'loss': '0.09177', 'grad_norm': '3.484', 'learning_rate': '5e-06', 'num_tokens': '2.743e+06', 'completions/mean_length': '36.12', 'completions/min_length': '9.3', 'completions/max_length': '207.7', 'completions/clipped_ratio': '0.0125', 'completions/mean_terminated_length': '33.34', 'completions/min_terminated_length': '9.3', 'completions/max_terminated_length': '146.1', 'rewards/reward_function/mean': '-0.3199', 'rewards/reward_function/std': '0.7088', 'reward': '-0.3199', 'reward_std': '0.7088', 'frac_reward_zero_std': '0.2062', 'kl': '0.06494', 'entropy': '1.696', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '36.11', 'epoch': '0.128'} +{'loss': '0.04645', 'grad_norm': '3.234', 'learning_rate': '5e-06', 'num_tokens': '3.082e+06', 'completions/mean_length': '33.89', 'completions/min_length': '9.6', 'completions/max_length': '180.4', 'completions/clipped_ratio': '0.01562', 'completions/mean_terminated_length': '30.4', 'completions/min_terminated_length': '9.6', 'completions/max_terminated_length': '111', 
'rewards/reward_function/mean': '-0.2734', 'rewards/reward_function/std': '0.7237', 'reward': '-0.2734', 'reward_std': '0.7237', 'frac_reward_zero_std': '0.3063', 'kl': '0.06039', 'entropy': '1.521', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '34.24', 'epoch': '0.144'}
+{'loss': '0.07534', 'grad_norm': '5.625', 'learning_rate': '5e-06', 'num_tokens': '3.443e+06', 'completions/mean_length': '30.04', 'completions/min_length': '8.2', 'completions/max_length': '137.4', 'completions/clipped_ratio': '0.003125', 'completions/mean_terminated_length': '29.33', 'completions/min_terminated_length': '8.2', 'completions/max_terminated_length': '110.4', 'rewards/reward_function/mean': '-0.2331', 'rewards/reward_function/std': '0.7244', 'reward': '-0.2331', 'reward_std': '0.7244', 'frac_reward_zero_std': '0.275', 'kl': '0.06758', 'entropy': '0.9546', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '32.34', 'epoch': '0.16'}
+ [tqdm progress-bar output stripped: steps 101-200 of 1875, roughly 23-51 s/it] {'loss': '0.04026', 'grad_norm': '6.75', 'learning_rate': '5e-06', 'num_tokens': '3.786e+06', 'completions/mean_length': '28.02', 'completions/min_length': '9.5', 'completions/max_length': '137.8', 'completions/clipped_ratio': '0.003125', 'completions/mean_terminated_length': '27.31', 'completions/min_terminated_length': '9.5', 'completions/max_terminated_length': '105.2', 'rewards/reward_function/mean': '-0.294', 'rewards/reward_function/std': '0.7176', 'reward': '-0.294', 'reward_std': '0.7176', 'frac_reward_zero_std': '0.325', 'kl': '0.07931', 'entropy': '0.8174', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '31.08', 'epoch': '0.176'}
+{'loss': '0.04267', 'grad_norm': '4.625', 'learning_rate': '5e-06', 'num_tokens': '4.129e+06',
'completions/mean_length': '25.05', 'completions/min_length': '9.3', 'completions/max_length': '101.4', 'completions/clipped_ratio': '0.003125', 'completions/mean_terminated_length': '24.33', 'completions/min_terminated_length': '9.3', 'completions/max_terminated_length': '63.9', 'rewards/reward_function/mean': '-0.1721', 'rewards/reward_function/std': '0.7501', 'reward': '-0.1721', 'reward_std': '0.7501', 'frac_reward_zero_std': '0.375', 'kl': '0.08927', 'entropy': '0.6128', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '27.31', 'epoch': '0.192'} +{'loss': '0.03885', 'grad_norm': '6.469', 'learning_rate': '5e-06', 'num_tokens': '4.458e+06', 'completions/mean_length': '25.95', 'completions/min_length': '8.3', 'completions/max_length': '99.8', 'completions/clipped_ratio': '0.001563', 'completions/mean_terminated_length': '25.6', 'completions/min_terminated_length': '8.3', 'completions/max_terminated_length': '81.9', 'rewards/reward_function/mean': '-0.2468', 'rewards/reward_function/std': '0.7458', 'reward': '-0.2468', 'reward_std': '0.7458', 'frac_reward_zero_std': '0.375', 'kl': '0.0872', 'entropy': '0.555', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '27.47', 'epoch': '0.208'} +{'loss': '0.04248', 'grad_norm': '10.5', 'learning_rate': '5e-06', 'num_tokens': '4.817e+06', 'completions/mean_length': '27.59', 'completions/min_length': '8.4', 'completions/max_length': '155.9', 'completions/clipped_ratio': '0.007812', 'completions/mean_terminated_length': '25.79', 'completions/min_terminated_length': '8.4', 'completions/max_terminated_length': '106.1', 'rewards/reward_function/mean': '-0.2004', 'rewards/reward_function/std': '0.7367', 'reward': '-0.2004', 'reward_std': '0.7367', 'frac_reward_zero_std': '0.3875', 'kl': '0.08259', 'entropy': '1.168', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '33.36', 'epoch': '0.224'} +{'loss': '0.04001', 'grad_norm': '7.062', 'learning_rate': '5e-06', 'num_tokens': '5.16e+06', 'completions/mean_length': '24.71', 'completions/min_length': '7.8', 'completions/max_length': '119', 'completions/clipped_ratio': '0.003125', 'completions/mean_terminated_length': '23.98', 'completions/min_terminated_length': '7.8', 'completions/max_terminated_length': '80.5', 'rewards/reward_function/mean': '-0.164', 'rewards/reward_function/std': '0.751', 'reward': '-0.164', 'reward_std': '0.751', 'frac_reward_zero_std': '0.425', 'kl': '0.08499', 'entropy': '0.7907', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '28.56', 'epoch': '0.24'} +{'loss': '0.04599', 'grad_norm': '6.688', 'learning_rate': '5e-06', 'num_tokens': '5.496e+06', 'completions/mean_length': '25.46', 'completions/min_length': '7.2', 'completions/max_length': '120.3', 'completions/clipped_ratio': '0.001563', 'completions/mean_terminated_length': '25.1', 'completions/min_terminated_length': '7.2', 'completions/max_terminated_length': '104.3', 'rewards/reward_function/mean': '-0.2934', 'rewards/reward_function/std': '0.7252', 'reward': '-0.2934', 'reward_std': '0.7252', 'frac_reward_zero_std': '0.3812', 'kl': '0.08188', 'entropy': '0.7384', 'clip_ratio/low_mean': '0', 
'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '27.9', 'epoch': '0.256'}
+{'loss': '0.05606', 'grad_norm': '5.781', 'learning_rate': '5e-06', 'num_tokens': '5.818e+06', 'completions/mean_length': '23.48', 'completions/min_length': '7.7', 'completions/max_length': '94.8', 'completions/clipped_ratio': '0.001563', 'completions/mean_terminated_length': '23.11', 'completions/min_terminated_length': '7.7', 'completions/max_terminated_length': '74.9', 'rewards/reward_function/mean': '-0.1709', 'rewards/reward_function/std': '0.7138', 'reward': '-0.1709', 'reward_std': '0.7138', 'frac_reward_zero_std': '0.4188', 'kl': '0.08449', 'entropy': '0.6329', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '25.82', 'epoch': '0.272'}
+{'loss': '0.03233', 'grad_norm': '4.562', 'learning_rate': '5e-06', 'num_tokens': '6.158e+06', 'completions/mean_length': '24.27', 'completions/min_length': '7.3', 'completions/max_length': '105.1', 'completions/clipped_ratio': '0.001563', 'completions/mean_terminated_length': '23.91', 'completions/min_terminated_length': '7.3', 'completions/max_terminated_length': '88', 'rewards/reward_function/mean': '-0.2321', 'rewards/reward_function/std': '0.7298', 'reward': '-0.2321', 'reward_std': '0.7298', 'frac_reward_zero_std': '0.4625', 'kl': '0.1018', 'entropy': '0.626', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '27.56', 'epoch': '0.288'}
+{'loss': '0.02813', 'grad_norm': '6', 'learning_rate': '5e-06', 'num_tokens': '6.495e+06', 'completions/mean_length': '21.37', 'completions/min_length': '7.1', 'completions/max_length': '66.7', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '21.37', 'completions/min_terminated_length': '7.1', 'completions/max_terminated_length': '66.7', 'rewards/reward_function/mean': '-0.1516', 'rewards/reward_function/std': '0.7251', 'reward': '-0.1516', 'reward_std': '0.7251', 'frac_reward_zero_std': '0.4562', 'kl': '0.09969', 'entropy': '0.3524', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '23.31', 'epoch': '0.304'}
+{'loss': '0.01319', 'grad_norm': '9.688', 'learning_rate': '5e-06', 'num_tokens': '6.822e+06', 'completions/mean_length': '23.63', 'completions/min_length': '7.1', 'completions/max_length': '110.4', 'completions/clipped_ratio': '0.003125', 'completions/mean_terminated_length': '22.89', 'completions/min_terminated_length': '7.1', 'completions/max_terminated_length': '77', 'rewards/reward_function/mean': '-0.2271', 'rewards/reward_function/std': '0.7689', 'reward': '-0.2271', 'reward_std': '0.7689', 'frac_reward_zero_std': '0.45', 'kl': '0.07928', 'entropy': '0.792', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '26.62', 'epoch': '0.32'}
+ [tqdm progress-bar output stripped: steps 201-300 of 1875, roughly 21-45 s/it] {'loss': '0.03715', 'grad_norm': '7.062', 'learning_rate': '5e-06', 'num_tokens': '7.15e+06', 'completions/mean_length': '22.55', 'completions/min_length': '7.2', 'completions/max_length': '107.6', 'completions/clipped_ratio': '0.003125', 'completions/mean_terminated_length': '21.82', 'completions/min_terminated_length': '7.2', 'completions/max_terminated_length': '70.6', 'rewards/reward_function/mean': '-0.1597', 'rewards/reward_function/std': '0.7546', 'reward': '-0.1597', 'reward_std': '0.7546', 'frac_reward_zero_std': '0.4125', 'kl': '0.1043', 'entropy': '0.704', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '27.5', 'epoch': '0.336'}
+{'loss': '0.01712', 'grad_norm': '5.25', 'learning_rate': '5e-06', 'num_tokens': '7.487e+06', 'completions/mean_length': '19.39', 'completions/min_length': '7', 'completions/max_length': '114.5', 'completions/clipped_ratio': '0.004687', 'completions/mean_terminated_length': '18.28', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '56', 'rewards/reward_function/mean': '-0.07466', 'rewards/reward_function/std': '0.7352', 'reward': '-0.07466', 'reward_std': '0.7352', 'frac_reward_zero_std': '0.4562', 'kl': '0.1034', 'entropy': '0.762', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '27.69', 'epoch': '0.352'}
+{'loss': '0.03657', 'grad_norm': '10.38', 'learning_rate': '5e-06', 'num_tokens': '7.822e+06', 'completions/mean_length': '19.63', 'completions/min_length': '7.1', 'completions/max_length': '67.2', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '19.63', 'completions/min_terminated_length': '7.1', 'completions/max_terminated_length': '67.2', 'rewards/reward_function/mean': '-0.1773', 'rewards/reward_function/std': '0.7373', 'reward': '-0.1773', 'reward_std': '0.7373', 'frac_reward_zero_std': '0.4625', 'kl': '0.118', 'entropy': '0.3687', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0',
'step_time': '23.44', 'epoch': '0.368'} +{'loss': '0.03408', 'grad_norm': '7.906', 'learning_rate': '5e-06', 'num_tokens': '8.184e+06', 'completions/mean_length': '20.55', 'completions/min_length': '7.1', 'completions/max_length': '61.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '20.55', 'completions/min_terminated_length': '7.1', 'completions/max_terminated_length': '61.9', 'rewards/reward_function/mean': '-0.1846', 'rewards/reward_function/std': '0.7266', 'reward': '-0.1846', 'reward_std': '0.7266', 'frac_reward_zero_std': '0.4938', 'kl': '0.1034', 'entropy': '0.3668', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '24.75', 'epoch': '0.384'} +{'loss': '0.0276', 'grad_norm': '3.766', 'learning_rate': '5e-06', 'num_tokens': '8.52e+06', 'completions/mean_length': '20.02', 'completions/min_length': '7', 'completions/max_length': '63.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '20.02', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '63.9', 'rewards/reward_function/mean': '-0.1902', 'rewards/reward_function/std': '0.7268', 'reward': '-0.1902', 'reward_std': '0.7268', 'frac_reward_zero_std': '0.4625', 'kl': '0.108', 'entropy': '0.3885', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.81', 'epoch': '0.4'} +{'loss': '0.04853', 'grad_norm': '12.44', 'learning_rate': '5e-06', 'num_tokens': '8.86e+06', 'completions/mean_length': '19.78', 'completions/min_length': '7', 'completions/max_length': '88.2', 'completions/clipped_ratio': '0.001563', 'completions/mean_terminated_length': '19.41', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '71.6', 'rewards/reward_function/mean': '-0.01014', 'rewards/reward_function/std': '0.7326', 'reward': '-0.01014', 'reward_std': '0.7326', 'frac_reward_zero_std': '0.4188', 'kl': '0.1068', 'entropy': '0.5668', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '25.89', 'epoch': '0.416'} +{'loss': '0.05604', 'grad_norm': '5.188', 'learning_rate': '5e-06', 'num_tokens': '9.196e+06', 'completions/mean_length': '20.62', 'completions/min_length': '7', 'completions/max_length': '64', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '20.62', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '64', 'rewards/reward_function/mean': '-0.1506', 'rewards/reward_function/std': '0.7477', 'reward': '-0.1506', 'reward_std': '0.7477', 'frac_reward_zero_std': '0.5125', 'kl': '0.1054', 'entropy': '0.3568', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '23.86', 'epoch': '0.432'} +{'loss': '0.05192', 'grad_norm': '7.25', 'learning_rate': '5e-06', 'num_tokens': '9.508e+06', 'completions/mean_length': '18.59', 'completions/min_length': '7', 'completions/max_length': '83.1', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '18.59', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '83.1', 'rewards/reward_function/mean': '-0.1603', 'rewards/reward_function/std': '0.7717', 'reward': '-0.1603', 'reward_std': '0.7717', 'frac_reward_zero_std': 
'0.5188', 'kl': '0.1193', 'entropy': '0.499', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '24.01', 'epoch': '0.448'}
+{'loss': '0.04143', 'grad_norm': '34.5', 'learning_rate': '5e-06', 'num_tokens': '9.827e+06', 'completions/mean_length': '16.19', 'completions/min_length': '7', 'completions/max_length': '53.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '16.19', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '53.9', 'rewards/reward_function/mean': '-0.0907', 'rewards/reward_function/std': '0.7427', 'reward': '-0.0907', 'reward_std': '0.7427', 'frac_reward_zero_std': '0.5625', 'kl': '0.155', 'entropy': '0.2627', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.75', 'epoch': '0.464'}
+{'loss': '0.01211', 'grad_norm': '6.656', 'learning_rate': '5e-06', 'num_tokens': '1.015e+07', 'completions/mean_length': '15.18', 'completions/min_length': '7', 'completions/max_length': '50.2', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '15.18', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '50.2', 'rewards/reward_function/mean': '-0.07846', 'rewards/reward_function/std': '0.7477', 'reward': '-0.07846', 'reward_std': '0.7477', 'frac_reward_zero_std': '0.5188', 'kl': '0.171', 'entropy': '0.2525', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.69', 'epoch': '0.48'}
+ [tqdm progress-bar output stripped: steps 301-400 of 1875, roughly 20-42 s/it] {'loss': '0.01658', 'grad_norm': '6.312', 'learning_rate': '5e-06', 'num_tokens': '1.047e+07', 'completions/mean_length': '14.44', 'completions/min_length': '7',
'completions/max_length': '43.3', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.44', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '43.3', 'rewards/reward_function/mean': '-0.01298', 'rewards/reward_function/std': '0.7544', 'reward': '-0.01298', 'reward_std': '0.7544', 'frac_reward_zero_std': '0.4813', 'kl': '0.2388', 'entropy': '0.2339', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.51', 'epoch': '0.496'} +{'loss': '0.05994', 'grad_norm': '29.62', 'learning_rate': '5e-06', 'num_tokens': '1.082e+07', 'completions/mean_length': '17.63', 'completions/min_length': '7', 'completions/max_length': '62.2', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '17.63', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '62.2', 'rewards/reward_function/mean': '-0.03098', 'rewards/reward_function/std': '0.7493', 'reward': '-0.03098', 'reward_std': '0.7493', 'frac_reward_zero_std': '0.5375', 'kl': '0.1627', 'entropy': '0.3015', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '24.23', 'epoch': '0.512'} +{'loss': '0.01882', 'grad_norm': '10.12', 'learning_rate': '5e-06', 'num_tokens': '1.117e+07', 'completions/mean_length': '14.92', 'completions/min_length': '7', 'completions/max_length': '51.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.92', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '51.8', 'rewards/reward_function/mean': '-0.02476', 'rewards/reward_function/std': '0.7494', 'reward': '-0.02476', 'reward_std': '0.7494', 'frac_reward_zero_std': '0.5062', 'kl': '0.2356', 'entropy': '0.3197', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.99', 'epoch': '0.528'} +{'loss': '0.05632', 'grad_norm': '8.062', 'learning_rate': '5e-06', 'num_tokens': '1.15e+07', 'completions/mean_length': '15.45', 'completions/min_length': '7', 'completions/max_length': '49.3', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '15.45', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '49.3', 'rewards/reward_function/mean': '-0.1001', 'rewards/reward_function/std': '0.7398', 'reward': '-0.1001', 'reward_std': '0.7398', 'frac_reward_zero_std': '0.5375', 'kl': '0.2206', 'entropy': '0.2328', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '23.21', 'epoch': '0.544'} +{'loss': '0.03395', 'grad_norm': '10.62', 'learning_rate': '5e-06', 'num_tokens': '1.183e+07', 'completions/mean_length': '14.18', 'completions/min_length': '7', 'completions/max_length': '63.7', 'completions/clipped_ratio': '0.001563', 'completions/mean_terminated_length': '13.8', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '41', 'rewards/reward_function/mean': '-0.03688', 'rewards/reward_function/std': '0.7718', 'reward': '-0.03688', 'reward_std': '0.7718', 'frac_reward_zero_std': '0.6375', 'kl': '0.2389', 'entropy': '0.4318', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 
'step_time': '24.92', 'epoch': '0.56'} +{'loss': '0.003839', 'grad_norm': '8.375', 'learning_rate': '5e-06', 'num_tokens': '1.214e+07', 'completions/mean_length': '14.36', 'completions/min_length': '7', 'completions/max_length': '45.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.36', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '45.9', 'rewards/reward_function/mean': '-0.05828', 'rewards/reward_function/std': '0.7081', 'reward': '-0.05828', 'reward_std': '0.7081', 'frac_reward_zero_std': '0.5625', 'kl': '0.2335', 'entropy': '0.216', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.15', 'epoch': '0.576'} +{'loss': '0.05172', 'grad_norm': '6.781', 'learning_rate': '5e-06', 'num_tokens': '1.248e+07', 'completions/mean_length': '15.7', 'completions/min_length': '7', 'completions/max_length': '52.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '15.7', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '52.4', 'rewards/reward_function/mean': '-0.09204', 'rewards/reward_function/std': '0.7514', 'reward': '-0.09204', 'reward_std': '0.7514', 'frac_reward_zero_std': '0.575', 'kl': '0.2227', 'entropy': '0.2359', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '23.01', 'epoch': '0.592'} +{'loss': '0.0228', 'grad_norm': '7.906', 'learning_rate': '5e-06', 'num_tokens': '1.282e+07', 'completions/mean_length': '13.63', 'completions/min_length': '7', 'completions/max_length': '57.6', 'completions/clipped_ratio': '0.001563', 'completions/mean_terminated_length': '13.25', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.8', 'rewards/reward_function/mean': '-0.1224', 'rewards/reward_function/std': '0.7635', 'reward': '-0.1224', 'reward_std': '0.7635', 'frac_reward_zero_std': '0.6375', 'kl': '0.2676', 'entropy': '0.4111', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.92', 'epoch': '0.608'} +{'loss': '0.03393', 'grad_norm': '9.875', 'learning_rate': '5e-06', 'num_tokens': '1.316e+07', 'completions/mean_length': '13.57', 'completions/min_length': '7', 'completions/max_length': '48.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.57', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '48.8', 'rewards/reward_function/mean': '-0.01127', 'rewards/reward_function/std': '0.7412', 'reward': '-0.01127', 'reward_std': '0.7412', 'frac_reward_zero_std': '0.6188', 'kl': '0.3007', 'entropy': '0.2005', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '23.28', 'epoch': '0.624'} +{'loss': '0.05116', 'grad_norm': '14', 'learning_rate': '5e-06', 'num_tokens': '1.348e+07', 'completions/mean_length': '13.77', 'completions/min_length': '7', 'completions/max_length': '42.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.77', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '42.4', 'rewards/reward_function/mean': '-0.1463', 'rewards/reward_function/std': '0.7154', 'reward': '-0.1463', 'reward_std': '0.7154', 'frac_reward_zero_std': 
'0.6062', 'kl': '0.228', 'entropy': '0.2214', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.36', 'epoch': '0.64'}
+ [tqdm progress-bar output stripped: steps 401-500 of 1875, roughly 18-40 s/it] {'loss': '0.05169', 'grad_norm': '6.125', 'learning_rate': '5e-06', 'num_tokens': '1.381e+07', 'completions/mean_length': '13.58', 'completions/min_length': '7', 'completions/max_length': '41.7', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.58', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '41.7', 'rewards/reward_function/mean': '0.005292', 'rewards/reward_function/std': '0.7257', 'reward': '0.005292', 'reward_std': '0.7257', 'frac_reward_zero_std': '0.55', 'kl': '0.2661', 'entropy': '0.1861', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.97', 'epoch': '0.656'}
+{'loss': '0.01523', 'grad_norm': '15.75', 'learning_rate': '5e-06', 'num_tokens': '1.414e+07', 'completions/mean_length': '14.07', 'completions/min_length': '7', 'completions/max_length': '57.6', 'completions/clipped_ratio': '0.001563', 'completions/mean_terminated_length': '13.7', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '36', 'rewards/reward_function/mean': '0.02591', 'rewards/reward_function/std': '0.7276', 'reward': '0.02591', 'reward_std': '0.7276', 'frac_reward_zero_std': '0.5563', 'kl': '0.2456', 'entropy': '0.3872', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.83', 'epoch': '0.672'}
+{'loss': '0.01811', 'grad_norm': '5.469', 'learning_rate': '5e-06', 'num_tokens':
'1.446e+07', 'completions/mean_length': '15.1', 'completions/min_length': '7', 'completions/max_length': '47', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '15.1', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '47', 'rewards/reward_function/mean': '-0.06534', 'rewards/reward_function/std': '0.7479', 'reward': '-0.06534', 'reward_std': '0.7479', 'frac_reward_zero_std': '0.5437', 'kl': '0.1931', 'entropy': '0.2438', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.27', 'epoch': '0.688'} +{'loss': '0.04902', 'grad_norm': '208', 'learning_rate': '5e-06', 'num_tokens': '1.48e+07', 'completions/mean_length': '15.8', 'completions/min_length': '7', 'completions/max_length': '68.5', 'completions/clipped_ratio': '0.001563', 'completions/mean_terminated_length': '15.43', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '49.3', 'rewards/reward_function/mean': '-0.01271', 'rewards/reward_function/std': '0.7409', 'reward': '-0.01271', 'reward_std': '0.7409', 'frac_reward_zero_std': '0.575', 'kl': '0.191', 'entropy': '0.4241', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '24.86', 'epoch': '0.704'} +{'loss': '0.03248', 'grad_norm': '4.344', 'learning_rate': '5e-06', 'num_tokens': '1.513e+07', 'completions/mean_length': '14.99', 'completions/min_length': '7', 'completions/max_length': '38.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.99', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '38.9', 'rewards/reward_function/mean': '-0.03651', 'rewards/reward_function/std': '0.7479', 'reward': '-0.03651', 'reward_std': '0.7479', 'frac_reward_zero_std': '0.5375', 'kl': '0.2514', 'entropy': '0.1759', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.01', 'epoch': '0.72'} +{'loss': '0.01309', 'grad_norm': '4.5', 'learning_rate': '5e-06', 'num_tokens': '1.546e+07', 'completions/mean_length': '14.03', 'completions/min_length': '7', 'completions/max_length': '34.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.03', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.5', 'rewards/reward_function/mean': '0.03686', 'rewards/reward_function/std': '0.7424', 'reward': '0.03686', 'reward_std': '0.7424', 'frac_reward_zero_std': '0.5813', 'kl': '0.2381', 'entropy': '0.1702', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.56', 'epoch': '0.736'} +{'loss': '0.0229', 'grad_norm': '6.219', 'learning_rate': '5e-06', 'num_tokens': '1.581e+07', 'completions/mean_length': '14.38', 'completions/min_length': '7', 'completions/max_length': '43.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.38', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '43.4', 'rewards/reward_function/mean': '0.02769', 'rewards/reward_function/std': '0.7268', 'reward': '0.02769', 'reward_std': '0.7268', 'frac_reward_zero_std': '0.6188', 'kl': '0.209', 'entropy': '0.1754', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 
'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.43', 'epoch': '0.752'}
+{'loss': '0.03038', 'grad_norm': '11.5', 'learning_rate': '5e-06', 'num_tokens': '1.61e+07', 'completions/mean_length': '14.03', 'completions/min_length': '7', 'completions/max_length': '45.3', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.03', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '45.3', 'rewards/reward_function/mean': '-0.02106', 'rewards/reward_function/std': '0.76', 'reward': '-0.02106', 'reward_std': '0.76', 'frac_reward_zero_std': '0.5437', 'kl': '0.2118', 'entropy': '0.2059', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.85', 'epoch': '0.768'}
+{'loss': '0.04238', 'grad_norm': '9.312', 'learning_rate': '5e-06', 'num_tokens': '1.643e+07', 'completions/mean_length': '14.4', 'completions/min_length': '7', 'completions/max_length': '38', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.4', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '38', 'rewards/reward_function/mean': '0.02066', 'rewards/reward_function/std': '0.7214', 'reward': '0.02066', 'reward_std': '0.7214', 'frac_reward_zero_std': '0.6062', 'kl': '0.2055', 'entropy': '0.2173', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.92', 'epoch': '0.784'}
+{'loss': '0.01798', 'grad_norm': '4.375', 'learning_rate': '5e-06', 'num_tokens': '1.672e+07', 'completions/mean_length': '12.77', 'completions/min_length': '7', 'completions/max_length': '34.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.77', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.8', 'rewards/reward_function/mean': '-0.00801', 'rewards/reward_function/std': '0.7225', 'reward': '-0.00801', 'reward_std': '0.7225', 'frac_reward_zero_std': '0.5875', 'kl': '0.2778', 'entropy': '0.1682', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.11', 'epoch': '0.8'}
+ [tqdm progress output condensed: steps 501-600/1875, elapsed 3:47:58 -> 4:21:58, typically ~20 s/it]
{'loss': '0.01209', 'grad_norm': '16.25', 'learning_rate': '5e-06', 'num_tokens': '1.705e+07', 'completions/mean_length': '13.72', 'completions/min_length': '7', 'completions/max_length': '30.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.72', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '30.4', 'rewards/reward_function/mean': '0.08989', 'rewards/reward_function/std': '0.7085', 'reward': '0.08989', 'reward_std': '0.7085', 'frac_reward_zero_std': '0.6625', 'kl': '0.2462', 'entropy': '0.1423', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.88', 'epoch': '0.816'}
+{'loss': '0.03224', 'grad_norm': '7.75', 'learning_rate': '5e-06', 'num_tokens': '1.739e+07', 'completions/mean_length': '13.3', 'completions/min_length': '7', 'completions/max_length': '33.2', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.3', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '33.2', 'rewards/reward_function/mean': '0.07375', 'rewards/reward_function/std': '0.6751', 'reward': '0.07375', 'reward_std': '0.6751', 'frac_reward_zero_std': '0.6062', 'kl': '0.2812', 'entropy': '0.1487', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.02', 'epoch': '0.832'}
+{'loss': '0.02197', 'grad_norm': '4.094', 'learning_rate': '5e-06', 'num_tokens': '1.772e+07', 'completions/mean_length': '12.96', 'completions/min_length': '7', 'completions/max_length': '32', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.96', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '32', 'rewards/reward_function/mean': '0.02511', 'rewards/reward_function/std': '0.7073', 'reward': '0.02511', 'reward_std': '0.7073', 'frac_reward_zero_std': '0.5875', 'kl': '0.2517', 'entropy': '0.1322', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20', 'epoch': '0.848'}
+{'loss': '0.01886', 'grad_norm': '5.375', 'learning_rate': '5e-06', 'num_tokens': '1.804e+07', 'completions/mean_length': '12.94', 'completions/min_length': '7', 'completions/max_length': '30.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.94', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '30.5', 'rewards/reward_function/mean': '0.007406', 'rewards/reward_function/std': '0.7242', 'reward': '0.007406', 'reward_std': '0.7242', 'frac_reward_zero_std': '0.675', 'kl': '0.2638', 'entropy': '0.1527', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.57', 'epoch': '0.864'}
+{'loss': '0.008496', 'grad_norm':
'7.062', 'learning_rate': '5e-06', 'num_tokens': '1.838e+07', 'completions/mean_length': '13.48', 'completions/min_length': '7', 'completions/max_length': '36.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.48', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '36.9', 'rewards/reward_function/mean': '-0.002011', 'rewards/reward_function/std': '0.7264', 'reward': '-0.002011', 'reward_std': '0.7264', 'frac_reward_zero_std': '0.6375', 'kl': '0.2658', 'entropy': '0.1613', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.75', 'epoch': '0.88'} +{'loss': '0.02598', 'grad_norm': '12.88', 'learning_rate': '5e-06', 'num_tokens': '1.869e+07', 'completions/mean_length': '13.59', 'completions/min_length': '7', 'completions/max_length': '30.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.59', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '30.4', 'rewards/reward_function/mean': '0.1092', 'rewards/reward_function/std': '0.7165', 'reward': '0.1092', 'reward_std': '0.7165', 'frac_reward_zero_std': '0.6062', 'kl': '0.2658', 'entropy': '0.1389', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.2', 'epoch': '0.896'} +{'loss': '0.02766', 'grad_norm': '14.56', 'learning_rate': '5e-06', 'num_tokens': '1.901e+07', 'completions/mean_length': '13.51', 'completions/min_length': '7', 'completions/max_length': '36.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.51', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '36.4', 'rewards/reward_function/mean': '-0.004352', 'rewards/reward_function/std': '0.6892', 'reward': '-0.004352', 'reward_std': '0.6892', 'frac_reward_zero_std': '0.6188', 'kl': '0.2559', 'entropy': '0.1556', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.96', 'epoch': '0.912'} +{'loss': '0.02565', 'grad_norm': '7.062', 'learning_rate': '5e-06', 'num_tokens': '1.933e+07', 'completions/mean_length': '13.19', 'completions/min_length': '7', 'completions/max_length': '42.6', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.19', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '42.6', 'rewards/reward_function/mean': '-0.06826', 'rewards/reward_function/std': '0.7174', 'reward': '-0.06826', 'reward_std': '0.7174', 'frac_reward_zero_std': '0.5687', 'kl': '0.2922', 'entropy': '0.182', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.76', 'epoch': '0.928'} +{'loss': '0.02615', 'grad_norm': '8.5', 'learning_rate': '5e-06', 'num_tokens': '1.964e+07', 'completions/mean_length': '13.4', 'completions/min_length': '7', 'completions/max_length': '36.6', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.4', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '36.6', 'rewards/reward_function/mean': '0.00657', 'rewards/reward_function/std': '0.6908', 'reward': '0.00657', 'reward_std': '0.6908', 'frac_reward_zero_std': '0.6062', 'kl': '0.2227', 'entropy': '0.1446', 'clip_ratio/low_mean': '0', 
'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.02', 'epoch': '0.944'}
+{'loss': '0.02896', 'grad_norm': '6.812', 'learning_rate': '5e-06', 'num_tokens': '1.996e+07', 'completions/mean_length': '13.45', 'completions/min_length': '7', 'completions/max_length': '35.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.45', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '35.5', 'rewards/reward_function/mean': '-0.03688', 'rewards/reward_function/std': '0.7151', 'reward': '-0.03688', 'reward_std': '0.7151', 'frac_reward_zero_std': '0.575', 'kl': '0.2406', 'entropy': '0.1538', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.79', 'epoch': '0.96'}
+ [tqdm progress output condensed: steps 601-700/1875, elapsed 4:23:27 -> 4:57:09, typically ~20 s/it]
{'loss': '0.01717', 'grad_norm': '11.5', 'learning_rate': '5e-06', 'num_tokens': '2.029e+07', 'completions/mean_length': '13.2', 'completions/min_length': '7.3', 'completions/max_length': '33.3', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.2', 'completions/min_terminated_length': '7.3', 'completions/max_terminated_length': '33.3', 'rewards/reward_function/mean': '0.01353', 'rewards/reward_function/std': '0.6911', 'reward': '0.01353', 'reward_std': '0.6911', 'frac_reward_zero_std': '0.6188', 'kl': '0.2293', 'entropy': '0.1748', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.02', 'epoch': '0.976'}
+{'loss': '0.01576', 'grad_norm': '7.594', 'learning_rate':
'5e-06', 'num_tokens': '2.062e+07', 'completions/mean_length': '13.43', 'completions/min_length': '7', 'completions/max_length': '34.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.43', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.5', 'rewards/reward_function/mean': '0.004465', 'rewards/reward_function/std': '0.6929', 'reward': '0.004465', 'reward_std': '0.6929', 'frac_reward_zero_std': '0.5875', 'kl': '0.2299', 'entropy': '0.1584', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.46', 'epoch': '0.992'} +{'loss': '0.02128', 'grad_norm': '4.688', 'learning_rate': '5e-06', 'num_tokens': '2.095e+07', 'completions/mean_length': '13.58', 'completions/min_length': '7', 'completions/max_length': '37.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.58', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '37.8', 'rewards/reward_function/mean': '0.07888', 'rewards/reward_function/std': '0.6833', 'reward': '0.07888', 'reward_std': '0.6833', 'frac_reward_zero_std': '0.6125', 'kl': '0.242', 'entropy': '0.2066', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.7', 'epoch': '1.008'} +{'loss': '0.02832', 'grad_norm': '7.031', 'learning_rate': '5e-06', 'num_tokens': '2.127e+07', 'completions/mean_length': '14.64', 'completions/min_length': '7', 'completions/max_length': '37.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.64', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '37.8', 'rewards/reward_function/mean': '0.1458', 'rewards/reward_function/std': '0.6771', 'reward': '0.1458', 'reward_std': '0.6771', 'frac_reward_zero_std': '0.5938', 'kl': '0.213', 'entropy': '0.1243', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.64', 'epoch': '1.024'} +{'loss': '0.0269', 'grad_norm': '11.44', 'learning_rate': '5e-06', 'num_tokens': '2.161e+07', 'completions/mean_length': '13.44', 'completions/min_length': '7', 'completions/max_length': '34', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.44', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34', 'rewards/reward_function/mean': '0.06251', 'rewards/reward_function/std': '0.7168', 'reward': '0.06251', 'reward_std': '0.7168', 'frac_reward_zero_std': '0.6375', 'kl': '0.2525', 'entropy': '0.1479', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.55', 'epoch': '1.04'} +{'loss': '0.02449', 'grad_norm': '12.44', 'learning_rate': '5e-06', 'num_tokens': '2.192e+07', 'completions/mean_length': '12.79', 'completions/min_length': '7', 'completions/max_length': '30.3', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.79', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '30.3', 'rewards/reward_function/mean': '0.03252', 'rewards/reward_function/std': '0.7118', 'reward': '0.03252', 'reward_std': '0.7118', 'frac_reward_zero_std': '0.675', 'kl': '0.2639', 'entropy': '0.1426', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 
'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.18', 'epoch': '1.056'}
+{'loss': '0.03239', 'grad_norm': '8.812', 'learning_rate': '5e-06', 'num_tokens': '2.224e+07', 'completions/mean_length': '12.78', 'completions/min_length': '7', 'completions/max_length': '34.1', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.78', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.1', 'rewards/reward_function/mean': '0.02492', 'rewards/reward_function/std': '0.68', 'reward': '0.02492', 'reward_std': '0.68', 'frac_reward_zero_std': '0.6375', 'kl': '0.3384', 'entropy': '0.1287', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.71', 'epoch': '1.072'}
+{'loss': '0.02644', 'grad_norm': '6.594', 'learning_rate': '5e-06', 'num_tokens': '2.255e+07', 'completions/mean_length': '12.51', 'completions/min_length': '7', 'completions/max_length': '33.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.51', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '33.4', 'rewards/reward_function/mean': '-0.02095', 'rewards/reward_function/std': '0.7234', 'reward': '-0.02095', 'reward_std': '0.7234', 'frac_reward_zero_std': '0.7125', 'kl': '0.2859', 'entropy': '0.1189', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.51', 'epoch': '1.088'}
+{'loss': '0.0134', 'grad_norm': '10.12', 'learning_rate': '5e-06', 'num_tokens': '2.288e+07', 'completions/mean_length': '12.44', 'completions/min_length': '7', 'completions/max_length': '28.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.44', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '28.9', 'rewards/reward_function/mean': '-0.01558', 'rewards/reward_function/std': '0.7076', 'reward': '-0.01558', 'reward_std': '0.7076', 'frac_reward_zero_std': '0.6813', 'kl': '0.3111', 'entropy': '0.1502', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.81', 'epoch': '1.104'}
+{'loss': '0.007574', 'grad_norm': '13.12', 'learning_rate': '5e-06', 'num_tokens': '2.322e+07', 'completions/mean_length': '13.07', 'completions/min_length': '7', 'completions/max_length': '25.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.07', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '25.9', 'rewards/reward_function/mean': '0.003545', 'rewards/reward_function/std': '0.7715', 'reward': '0.003545', 'reward_std': '0.7715', 'frac_reward_zero_std': '0.6625', 'kl': '0.3013', 'entropy': '0.1259', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.05', 'epoch': '1.12'}
+ [tqdm progress output condensed: steps 701-800/1875, elapsed 4:58:36 -> 5:33:20, typically ~20 s/it]
{'loss': '0.01373', 'grad_norm': '4.312', 'learning_rate': '5e-06', 'num_tokens': '2.354e+07', 'completions/mean_length': '12.2', 'completions/min_length': '7', 'completions/max_length': '24.1', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.2', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '24.1', 'rewards/reward_function/mean': '0.01486', 'rewards/reward_function/std': '0.6954', 'reward': '0.01486', 'reward_std': '0.6954', 'frac_reward_zero_std': '0.6562', 'kl': '0.2933', 'entropy': '0.1155', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.05', 'epoch': '1.136'}
+{'loss': '0.01392', 'grad_norm': '6.938', 'learning_rate': '5e-06', 'num_tokens': '2.388e+07', 'completions/mean_length': '13.75', 'completions/min_length': '7', 'completions/max_length': '29.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.75', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '29.5', 'rewards/reward_function/mean': '0.09738', 'rewards/reward_function/std': '0.7002', 'reward': '0.09738', 'reward_std': '0.7002', 'frac_reward_zero_std': '0.6188', 'kl': '0.2645', 'entropy': '0.1361', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.96', 'epoch': '1.152'}
+{'loss': '0.01506', 'grad_norm': '11', 'learning_rate': '5e-06', 'num_tokens': '2.42e+07', 'completions/mean_length': '13.14', 'completions/min_length': '7', 'completions/max_length': '32.2', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.14', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '32.2', 'rewards/reward_function/mean': '0.1005', 'rewards/reward_function/std': '0.7025', 'reward': '0.1005', 'reward_std': '0.7025', 'frac_reward_zero_std': '0.5687', 'kl': '0.2478', 'entropy': '0.1702', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0',
'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21', 'epoch': '1.168'} +{'loss': '0.00784', 'grad_norm': '11.19', 'learning_rate': '5e-06', 'num_tokens': '2.455e+07', 'completions/mean_length': '12.84', 'completions/min_length': '7', 'completions/max_length': '34.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.84', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.5', 'rewards/reward_function/mean': '-0.04709', 'rewards/reward_function/std': '0.6976', 'reward': '-0.04709', 'reward_std': '0.6976', 'frac_reward_zero_std': '0.5875', 'kl': '0.3239', 'entropy': '0.1533', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.84', 'epoch': '1.184'} +{'loss': '0.02753', 'grad_norm': '9.938', 'learning_rate': '5e-06', 'num_tokens': '2.487e+07', 'completions/mean_length': '13.62', 'completions/min_length': '7', 'completions/max_length': '47.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.62', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '47.5', 'rewards/reward_function/mean': '0.05533', 'rewards/reward_function/std': '0.7101', 'reward': '0.05533', 'reward_std': '0.7101', 'frac_reward_zero_std': '0.6937', 'kl': '0.2711', 'entropy': '0.1417', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.24', 'epoch': '1.2'} +{'loss': '0.01706', 'grad_norm': '27.75', 'learning_rate': '5e-06', 'num_tokens': '2.518e+07', 'completions/mean_length': '12.26', 'completions/min_length': '7', 'completions/max_length': '30.3', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.26', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '30.3', 'rewards/reward_function/mean': '-0.01604', 'rewards/reward_function/std': '0.7219', 'reward': '-0.01604', 'reward_std': '0.7219', 'frac_reward_zero_std': '0.65', 'kl': '0.3838', 'entropy': '0.1524', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.29', 'epoch': '1.216'} +{'loss': '0.01725', 'grad_norm': '8.875', 'learning_rate': '5e-06', 'num_tokens': '2.551e+07', 'completions/mean_length': '12.91', 'completions/min_length': '7', 'completions/max_length': '37.1', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.91', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '37.1', 'rewards/reward_function/mean': '0.06305', 'rewards/reward_function/std': '0.6908', 'reward': '0.06305', 'reward_std': '0.6908', 'frac_reward_zero_std': '0.5687', 'kl': '0.2727', 'entropy': '0.152', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.78', 'epoch': '1.232'} +{'loss': '0.04639', 'grad_norm': '7.062', 'learning_rate': '5e-06', 'num_tokens': '2.587e+07', 'completions/mean_length': '14.27', 'completions/min_length': '7', 'completions/max_length': '43.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.27', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '43.8', 'rewards/reward_function/mean': '0.0992', 'rewards/reward_function/std': 
'0.6961', 'reward': '0.0992', 'reward_std': '0.6961', 'frac_reward_zero_std': '0.5938', 'kl': '0.2647', 'entropy': '0.1384', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '23.11', 'epoch': '1.248'}
+{'loss': '0.01232', 'grad_norm': '7.969', 'learning_rate': '5e-06', 'num_tokens': '2.62e+07', 'completions/mean_length': '13.33', 'completions/min_length': '7', 'completions/max_length': '36.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.33', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '36.5', 'rewards/reward_function/mean': '0.02043', 'rewards/reward_function/std': '0.7141', 'reward': '0.02043', 'reward_std': '0.7141', 'frac_reward_zero_std': '0.5687', 'kl': '0.2635', 'entropy': '0.15', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.08', 'epoch': '1.264'}
+{'loss': '0.03148', 'grad_norm': '7.469', 'learning_rate': '5e-06', 'num_tokens': '2.653e+07', 'completions/mean_length': '13.8', 'completions/min_length': '7', 'completions/max_length': '42.7', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.8', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '42.7', 'rewards/reward_function/mean': '0.03885', 'rewards/reward_function/std': '0.7014', 'reward': '0.03885', 'reward_std': '0.7014', 'frac_reward_zero_std': '0.6125', 'kl': '0.2335', 'entropy': '0.1653', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.37', 'epoch': '1.28'}
+ [tqdm progress output condensed: steps 801-900/1875, elapsed 5:34:48 -> 6:09:59, typically ~20 s/it]
{'loss': '0.01215', 'grad_norm': '13.38', 'learning_rate': '5e-06', 'num_tokens': '2.685e+07', 'completions/mean_length': '13.39', 'completions/min_length': '7', 'completions/max_length': '45.6', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.39', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '45.6', 'rewards/reward_function/mean': '-0.004922', 'rewards/reward_function/std': '0.7225', 'reward': '-0.004922', 'reward_std': '0.7225', 'frac_reward_zero_std': '0.575', 'kl': '0.2708', 'entropy': '0.1764', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.92', 'epoch': '1.296'}
+{'loss': '0.02334', 'grad_norm': '12.62', 'learning_rate': '5e-06', 'num_tokens': '2.72e+07', 'completions/mean_length': '14.15', 'completions/min_length': '7', 'completions/max_length': '38.7', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.15', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '38.7', 'rewards/reward_function/mean': '0.09831', 'rewards/reward_function/std': '0.692', 'reward': '0.09831', 'reward_std': '0.692', 'frac_reward_zero_std': '0.65', 'kl': '0.2292', 'entropy': '0.1498', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.55', 'epoch': '1.312'}
+{'loss': '0.01817', 'grad_norm': '15.06', 'learning_rate': '5e-06', 'num_tokens': '2.75e+07', 'completions/mean_length': '12.17', 'completions/min_length': '7', 'completions/max_length': '28.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.17', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '28.9', 'rewards/reward_function/mean': '0.02458', 'rewards/reward_function/std': '0.7055', 'reward': '0.02458', 'reward_std': '0.7055', 'frac_reward_zero_std': '0.6937', 'kl': '0.2565', 'entropy': '0.1424', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.65', 'epoch': '1.328'}
+{'loss': '0.04448', 'grad_norm': '8.062', 'learning_rate': '5e-06', 'num_tokens': '2.782e+07', 'completions/mean_length': '13.55', 'completions/min_length': '7', 'completions/max_length': '34.6', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.55', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.6', 'rewards/reward_function/mean': '0.0596', 'rewards/reward_function/std': '0.7048', 'reward': '0.0596', 'reward_std': '0.7048', 'frac_reward_zero_std': '0.5437', 'kl': '0.3051', 'entropy': '0.1581', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.92', 'epoch': '1.344'}
+{'loss': '0.02221', 'grad_norm': '9.25', 'learning_rate': '5e-06', 'num_tokens': '2.817e+07', 'completions/mean_length': '14.04', 'completions/min_length': '7', 'completions/max_length': '35.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.04', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '35.4',
'rewards/reward_function/mean': '0.07037', 'rewards/reward_function/std': '0.6958', 'reward': '0.07037', 'reward_std': '0.6958', 'frac_reward_zero_std': '0.5938', 'kl': '0.2594', 'entropy': '0.1511', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.87', 'epoch': '1.36'} +{'loss': '0.03586', 'grad_norm': '5.312', 'learning_rate': '5e-06', 'num_tokens': '2.851e+07', 'completions/mean_length': '13.77', 'completions/min_length': '7', 'completions/max_length': '32.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.77', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '32.5', 'rewards/reward_function/mean': '0.02103', 'rewards/reward_function/std': '0.7268', 'reward': '0.02103', 'reward_std': '0.7268', 'frac_reward_zero_std': '0.6312', 'kl': '0.26', 'entropy': '0.1455', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21', 'epoch': '1.376'} +{'loss': '0.01759', 'grad_norm': '8.562', 'learning_rate': '5e-06', 'num_tokens': '2.885e+07', 'completions/mean_length': '13.1', 'completions/min_length': '7', 'completions/max_length': '30.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.1', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '30.9', 'rewards/reward_function/mean': '0.02683', 'rewards/reward_function/std': '0.724', 'reward': '0.02683', 'reward_std': '0.724', 'frac_reward_zero_std': '0.6375', 'kl': '0.2306', 'entropy': '0.1256', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.43', 'epoch': '1.392'} +{'loss': '0.0188', 'grad_norm': '5.312', 'learning_rate': '5e-06', 'num_tokens': '2.918e+07', 'completions/mean_length': '13.21', 'completions/min_length': '7', 'completions/max_length': '34.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.21', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.4', 'rewards/reward_function/mean': '0.07696', 'rewards/reward_function/std': '0.6649', 'reward': '0.07696', 'reward_std': '0.6649', 'frac_reward_zero_std': '0.6875', 'kl': '0.2565', 'entropy': '0.157', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.19', 'epoch': '1.408'} +{'loss': '0.04949', 'grad_norm': '9.25', 'learning_rate': '5e-06', 'num_tokens': '2.951e+07', 'completions/mean_length': '13.27', 'completions/min_length': '7', 'completions/max_length': '40', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.27', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '40', 'rewards/reward_function/mean': '-0.003029', 'rewards/reward_function/std': '0.6885', 'reward': '-0.003029', 'reward_std': '0.6885', 'frac_reward_zero_std': '0.7', 'kl': '0.2933', 'entropy': '0.2583', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.57', 'epoch': '1.424'} +{'loss': '0.02414', 'grad_norm': '125', 'learning_rate': '5e-06', 'num_tokens': '2.984e+07', 'completions/mean_length': '14.38', 'completions/min_length': '7', 'completions/max_length': '40.2', 
'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.38', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '40.2', 'rewards/reward_function/mean': '0.1392', 'rewards/reward_function/std': '0.6645', 'reward': '0.1392', 'reward_std': '0.6645', 'frac_reward_zero_std': '0.6375', 'kl': '0.2777', 'entropy': '0.1298', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.31', 'epoch': '1.44'}
+ [progress bar: 901/1875 → 1000/1875, elapsed 6:11:26 → 6:45:45, ≈20 s/it] {'loss': '0.02703', 'grad_norm': '15.25', 'learning_rate': '5e-06', 'num_tokens': '3.015e+07', 'completions/mean_length': '12.93', 'completions/min_length': '7', 'completions/max_length': '31.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.93', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '31.4', 'rewards/reward_function/mean': '0.09445', 'rewards/reward_function/std': '0.701', 'reward': '0.09445', 'reward_std': '0.701', 'frac_reward_zero_std': '0.625', 'kl': '0.2366', 'entropy': '0.1395', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.13', 'epoch': '1.456'}
+{'loss': '-0.002355', 'grad_norm': '18.25', 'learning_rate': '5e-06', 'num_tokens': '3.047e+07', 'completions/mean_length': '12.9', 'completions/min_length': '7', 'completions/max_length': '33.6', 'completions/clipped_ratio': '0',
'completions/mean_terminated_length': '12.9', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '33.6', 'rewards/reward_function/mean': '0.07196', 'rewards/reward_function/std': '0.7268', 'reward': '0.07196', 'reward_std': '0.7268', 'frac_reward_zero_std': '0.6562', 'kl': '0.2557', 'entropy': '0.1253', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.85', 'epoch': '1.472'} +{'loss': '0.03109', 'grad_norm': '5.562', 'learning_rate': '5e-06', 'num_tokens': '3.079e+07', 'completions/mean_length': '12.35', 'completions/min_length': '7', 'completions/max_length': '32.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.35', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '32.5', 'rewards/reward_function/mean': '-0.01535', 'rewards/reward_function/std': '0.7508', 'reward': '-0.01535', 'reward_std': '0.7508', 'frac_reward_zero_std': '0.6562', 'kl': '0.2845', 'entropy': '0.152', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.36', 'epoch': '1.488'} +{'loss': '0.01116', 'grad_norm': '13.5', 'learning_rate': '5e-06', 'num_tokens': '3.112e+07', 'completions/mean_length': '12.59', 'completions/min_length': '7', 'completions/max_length': '27.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.59', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '27.8', 'rewards/reward_function/mean': '0.02258', 'rewards/reward_function/std': '0.7219', 'reward': '0.02258', 'reward_std': '0.7219', 'frac_reward_zero_std': '0.6375', 'kl': '0.3145', 'entropy': '0.1267', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.28', 'epoch': '1.504'} +{'loss': '0.01967', 'grad_norm': '15.62', 'learning_rate': '5e-06', 'num_tokens': '3.145e+07', 'completions/mean_length': '13.54', 'completions/min_length': '7', 'completions/max_length': '31.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.54', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '31.9', 'rewards/reward_function/mean': '0.07259', 'rewards/reward_function/std': '0.7198', 'reward': '0.07259', 'reward_std': '0.7198', 'frac_reward_zero_std': '0.6062', 'kl': '0.2934', 'entropy': '0.1336', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.48', 'epoch': '1.52'} +{'loss': '0.02238', 'grad_norm': '10.56', 'learning_rate': '5e-06', 'num_tokens': '3.179e+07', 'completions/mean_length': '13.37', 'completions/min_length': '7', 'completions/max_length': '32.2', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.37', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '32.2', 'rewards/reward_function/mean': '0.1273', 'rewards/reward_function/std': '0.6514', 'reward': '0.1273', 'reward_std': '0.6514', 'frac_reward_zero_std': '0.5687', 'kl': '0.326', 'entropy': '0.135', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.44', 'epoch': '1.536'} +{'loss': '0.0301', 'grad_norm': '12.88', 
'learning_rate': '5e-06', 'num_tokens': '3.213e+07', 'completions/mean_length': '13.63', 'completions/min_length': '7', 'completions/max_length': '43.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.63', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '43.8', 'rewards/reward_function/mean': '0.03713', 'rewards/reward_function/std': '0.6989', 'reward': '0.03713', 'reward_std': '0.6989', 'frac_reward_zero_std': '0.6062', 'kl': '0.2511', 'entropy': '0.3191', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.66', 'epoch': '1.552'}
+{'loss': '0.01786', 'grad_norm': '7.094', 'learning_rate': '5e-06', 'num_tokens': '3.245e+07', 'completions/mean_length': '13.18', 'completions/min_length': '7', 'completions/max_length': '37.6', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.18', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '37.6', 'rewards/reward_function/mean': '0.07022', 'rewards/reward_function/std': '0.696', 'reward': '0.07022', 'reward_std': '0.696', 'frac_reward_zero_std': '0.6188', 'kl': '0.2531', 'entropy': '0.1578', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.04', 'epoch': '1.568'}
+{'loss': '0.02319', 'grad_norm': '6.688', 'learning_rate': '5e-06', 'num_tokens': '3.278e+07', 'completions/mean_length': '12.86', 'completions/min_length': '7', 'completions/max_length': '34.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.86', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.4', 'rewards/reward_function/mean': '0.01617', 'rewards/reward_function/std': '0.7148', 'reward': '0.01617', 'reward_std': '0.7148', 'frac_reward_zero_std': '0.6375', 'kl': '0.272', 'entropy': '0.1632', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.01', 'epoch': '1.584'}
+{'loss': '0.03764', 'grad_norm': '16.12', 'learning_rate': '5e-06', 'num_tokens': '3.308e+07', 'completions/mean_length': '12.66', 'completions/min_length': '7', 'completions/max_length': '39.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.66', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '39.5', 'rewards/reward_function/mean': '0.02119', 'rewards/reward_function/std': '0.7281', 'reward': '0.02119', 'reward_std': '0.7281', 'frac_reward_zero_std': '0.6438', 'kl': '0.2876', 'entropy': '0.1778', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.05', 'epoch': '1.6'}
+ [progress bar: 1001/1875 → 1100/1875, elapsed 6:47:09 → 7:21:24, ≈20 s/it] {'loss': '0.01345', 'grad_norm': '8.562', 'learning_rate': '5e-06', 'num_tokens': '3.341e+07', 'completions/mean_length': '14.02', 'completions/min_length': '7', 'completions/max_length': '32.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.02', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '32.4', 'rewards/reward_function/mean': '0.1226', 'rewards/reward_function/std': '0.694', 'reward': '0.1226', 'reward_std': '0.694', 'frac_reward_zero_std': '0.5375', 'kl': '0.2247', 'entropy': '0.1391', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.23', 'epoch': '1.616'}
+{'loss': '0.01641', 'grad_norm': '10.75', 'learning_rate': '5e-06', 'num_tokens': '3.373e+07', 'completions/mean_length': '13.95', 'completions/min_length': '6.8', 'completions/max_length': '42.1', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.95', 'completions/min_terminated_length': '6.8', 'completions/max_terminated_length': '42.1', 'rewards/reward_function/mean': '-0.04209', 'rewards/reward_function/std': '0.7373', 'reward': '-0.04209', 'reward_std': '0.7373', 'frac_reward_zero_std': '0.6188', 'kl': '0.2456', 'entropy': '0.1852', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.48', 'epoch': '1.632'}
+{'loss': '0.02602', 'grad_norm': '5.688', 'learning_rate': '5e-06', 'num_tokens': '3.404e+07', 'completions/mean_length': '13.4', 'completions/min_length': '7', 'completions/max_length': '40.2', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.4', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '40.2', 'rewards/reward_function/mean': '0.07521', 'rewards/reward_function/std': '0.695', 'reward': '0.07521', 'reward_std':
'0.695', 'frac_reward_zero_std': '0.6188', 'kl': '0.2505', 'entropy': '0.1921', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.26', 'epoch': '1.648'} +{'loss': '0.01673', 'grad_norm': '7.25', 'learning_rate': '5e-06', 'num_tokens': '3.439e+07', 'completions/mean_length': '13.27', 'completions/min_length': '7', 'completions/max_length': '34.3', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.27', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.3', 'rewards/reward_function/mean': '0.04157', 'rewards/reward_function/std': '0.7178', 'reward': '0.04157', 'reward_std': '0.7178', 'frac_reward_zero_std': '0.6937', 'kl': '0.2737', 'entropy': '0.1291', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.65', 'epoch': '1.664'} +{'loss': '0.01199', 'grad_norm': '14', 'learning_rate': '5e-06', 'num_tokens': '3.47e+07', 'completions/mean_length': '13.32', 'completions/min_length': '7', 'completions/max_length': '35.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.32', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '35.9', 'rewards/reward_function/mean': '0.04884', 'rewards/reward_function/std': '0.6952', 'reward': '0.04884', 'reward_std': '0.6952', 'frac_reward_zero_std': '0.6438', 'kl': '0.2668', 'entropy': '0.156', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.39', 'epoch': '1.68'} +{'loss': '0.01264', 'grad_norm': '13.12', 'learning_rate': '5e-06', 'num_tokens': '3.503e+07', 'completions/mean_length': '13.41', 'completions/min_length': '7', 'completions/max_length': '35.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.41', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '35.8', 'rewards/reward_function/mean': '0.06729', 'rewards/reward_function/std': '0.737', 'reward': '0.06729', 'reward_std': '0.737', 'frac_reward_zero_std': '0.5813', 'kl': '0.2794', 'entropy': '0.1646', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.69', 'epoch': '1.696'} +{'loss': '0.02006', 'grad_norm': '10.88', 'learning_rate': '5e-06', 'num_tokens': '3.538e+07', 'completions/mean_length': '13.89', 'completions/min_length': '7', 'completions/max_length': '33.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.89', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '33.4', 'rewards/reward_function/mean': '0.1455', 'rewards/reward_function/std': '0.6981', 'reward': '0.1455', 'reward_std': '0.6981', 'frac_reward_zero_std': '0.6438', 'kl': '0.2743', 'entropy': '0.1369', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.97', 'epoch': '1.712'} +{'loss': '0.03162', 'grad_norm': '12.81', 'learning_rate': '5e-06', 'num_tokens': '3.57e+07', 'completions/mean_length': '13.47', 'completions/min_length': '7', 'completions/max_length': '42', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.47', 
'completions/min_terminated_length': '7', 'completions/max_terminated_length': '42', 'rewards/reward_function/mean': '0.009103', 'rewards/reward_function/std': '0.7054', 'reward': '0.009103', 'reward_std': '0.7054', 'frac_reward_zero_std': '0.575', 'kl': '0.259', 'entropy': '0.1626', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.2', 'epoch': '1.728'}
+{'loss': '0.04219', 'grad_norm': '3.406', 'learning_rate': '5e-06', 'num_tokens': '3.602e+07', 'completions/mean_length': '14.05', 'completions/min_length': '7', 'completions/max_length': '46', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.05', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '46', 'rewards/reward_function/mean': '0.04001', 'rewards/reward_function/std': '0.6995', 'reward': '0.04001', 'reward_std': '0.6995', 'frac_reward_zero_std': '0.6438', 'kl': '0.2879', 'entropy': '0.18', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.54', 'epoch': '1.744'}
+{'loss': '0.01481', 'grad_norm': '6.375', 'learning_rate': '5e-06', 'num_tokens': '3.635e+07', 'completions/mean_length': '14.08', 'completions/min_length': '7', 'completions/max_length': '41.3', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.08', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '41.3', 'rewards/reward_function/mean': '0.06397', 'rewards/reward_function/std': '0.7299', 'reward': '0.06397', 'reward_std': '0.7299', 'frac_reward_zero_std': '0.6438', 'kl': '0.2741', 'entropy': '0.2259', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.34', 'epoch': '1.76'}
+ [progress bar: 1101/1875 → 1200/1875, elapsed 7:22:53 → 7:57:14, ≈21 s/it] {'loss': '0.01474', 'grad_norm': '47', 'learning_rate': '5e-06', 'num_tokens': '3.668e+07', 'completions/mean_length': '13.33', 'completions/min_length': '7', 'completions/max_length': '33', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.33', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '33', 'rewards/reward_function/mean': '0.03682', 'rewards/reward_function/std': '0.7333', 'reward': '0.03682', 'reward_std': '0.7333', 'frac_reward_zero_std': '0.6125', 'kl': '0.3013', 'entropy': '0.1307', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.28', 'epoch': '1.776'}
+{'loss': '0.02292', 'grad_norm': '20', 'learning_rate': '5e-06', 'num_tokens': '3.701e+07', 'completions/mean_length': '13.1', 'completions/min_length': '7', 'completions/max_length': '32.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.1', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '32.5', 'rewards/reward_function/mean': '0.03945', 'rewards/reward_function/std': '0.7357', 'reward': '0.03945', 'reward_std': '0.7357', 'frac_reward_zero_std': '0.6562', 'kl': '0.2901', 'entropy': '0.1239', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.81', 'epoch': '1.792'}
+{'loss': '0.02499', 'grad_norm': '5.406', 'learning_rate': '5e-06', 'num_tokens': '3.732e+07', 'completions/mean_length': '12.65', 'completions/min_length': '7', 'completions/max_length': '33.2', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.65', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '33.2', 'rewards/reward_function/mean': '0.05111', 'rewards/reward_function/std': '0.7069', 'reward': '0.05111', 'reward_std': '0.7069', 'frac_reward_zero_std': '0.7', 'kl': '0.2406', 'entropy': '0.1552', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.96', 'epoch': '1.808'}
+{'loss': '0.03399', 'grad_norm': '9.438', 'learning_rate': '5e-06', 'num_tokens': '3.764e+07', 'completions/mean_length': '13.39', 'completions/min_length': '7', 'completions/max_length': '33.3', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.39', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '33.3', 'rewards/reward_function/mean': '0.01751', 'rewards/reward_function/std': '0.7247', 'reward': '0.01751', 'reward_std': '0.7247', 'frac_reward_zero_std': '0.6438', 'kl': '0.2921', 'entropy': '0.1583', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0',
'clip_ratio/region_mean': '0', 'step_time': '20.09', 'epoch': '1.824'} +{'loss': '0.02256', 'grad_norm': '3.531', 'learning_rate': '5e-06', 'num_tokens': '3.795e+07', 'completions/mean_length': '13.11', 'completions/min_length': '7', 'completions/max_length': '47.7', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.11', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '47.7', 'rewards/reward_function/mean': '0.09939', 'rewards/reward_function/std': '0.7073', 'reward': '0.09939', 'reward_std': '0.7073', 'frac_reward_zero_std': '0.6875', 'kl': '0.2804', 'entropy': '0.1474', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.94', 'epoch': '1.84'} +{'loss': '0.01209', 'grad_norm': '3.828', 'learning_rate': '5e-06', 'num_tokens': '3.827e+07', 'completions/mean_length': '13.17', 'completions/min_length': '7', 'completions/max_length': '37.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.17', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '37.5', 'rewards/reward_function/mean': '0.02216', 'rewards/reward_function/std': '0.7076', 'reward': '0.02216', 'reward_std': '0.7076', 'frac_reward_zero_std': '0.625', 'kl': '0.2659', 'entropy': '0.1531', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.82', 'epoch': '1.856'} +{'loss': '0.03266', 'grad_norm': '8.688', 'learning_rate': '5e-06', 'num_tokens': '3.86e+07', 'completions/mean_length': '14.12', 'completions/min_length': '7', 'completions/max_length': '42.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.12', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '42.8', 'rewards/reward_function/mean': '0.008867', 'rewards/reward_function/std': '0.729', 'reward': '0.008867', 'reward_std': '0.729', 'frac_reward_zero_std': '0.6312', 'kl': '0.2146', 'entropy': '0.1576', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.36', 'epoch': '1.872'} +{'loss': '0.03482', 'grad_norm': '4.562', 'learning_rate': '5e-06', 'num_tokens': '3.895e+07', 'completions/mean_length': '13.85', 'completions/min_length': '7', 'completions/max_length': '35.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.85', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '35.4', 'rewards/reward_function/mean': '0.01655', 'rewards/reward_function/std': '0.7327', 'reward': '0.01655', 'reward_std': '0.7327', 'frac_reward_zero_std': '0.6375', 'kl': '0.2918', 'entropy': '0.1412', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.88', 'epoch': '1.888'} +{'loss': '0.01835', 'grad_norm': '13', 'learning_rate': '5e-06', 'num_tokens': '3.929e+07', 'completions/mean_length': '14.13', 'completions/min_length': '7', 'completions/max_length': '37.7', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.13', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '37.7', 'rewards/reward_function/mean': '0.1326', 'rewards/reward_function/std': '0.6737', 'reward': '0.1326', 'reward_std': '0.6737', 
'frac_reward_zero_std': '0.6625', 'kl': '0.2428', 'entropy': '0.1547', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.32', 'epoch': '1.904'}
+{'loss': '0.02323', 'grad_norm': '11.69', 'learning_rate': '5e-06', 'num_tokens': '3.965e+07', 'completions/mean_length': '13.69', 'completions/min_length': '7', 'completions/max_length': '30.6', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.69', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '30.6', 'rewards/reward_function/mean': '0.005084', 'rewards/reward_function/std': '0.7067', 'reward': '0.005084', 'reward_std': '0.7067', 'frac_reward_zero_std': '0.6', 'kl': '0.2772', 'entropy': '0.1451', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.4', 'epoch': '1.92'}
+ [progress bar: 1201/1875 → 1300/1875, elapsed 7:58:41 → 8:34:14, ≈22 s/it] {'loss': '0.03418', 'grad_norm': '4.5', 'learning_rate': '5e-06', 'num_tokens': '4e+07', 'completions/mean_length': '14.71', 'completions/min_length': '7', 'completions/max_length': '42.7', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length':
'14.71', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '42.7', 'rewards/reward_function/mean': '0.04243', 'rewards/reward_function/std': '0.6728', 'reward': '0.04243', 'reward_std': '0.6728', 'frac_reward_zero_std': '0.6125', 'kl': '0.2883', 'entropy': '0.1552', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.6', 'epoch': '1.936'} +{'loss': '0.02111', 'grad_norm': '9.188', 'learning_rate': '5e-06', 'num_tokens': '4.033e+07', 'completions/mean_length': '14.31', 'completions/min_length': '7', 'completions/max_length': '40.2', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.31', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '40.2', 'rewards/reward_function/mean': '0.04597', 'rewards/reward_function/std': '0.7033', 'reward': '0.04597', 'reward_std': '0.7033', 'frac_reward_zero_std': '0.6562', 'kl': '0.2175', 'entropy': '0.162', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.42', 'epoch': '1.952'} +{'loss': '0.02565', 'grad_norm': '8.125', 'learning_rate': '5e-06', 'num_tokens': '4.067e+07', 'completions/mean_length': '13.47', 'completions/min_length': '7', 'completions/max_length': '36.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.47', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '36.4', 'rewards/reward_function/mean': '-0.0002169', 'rewards/reward_function/std': '0.7255', 'reward': '-0.0002169', 'reward_std': '0.7255', 'frac_reward_zero_std': '0.65', 'kl': '0.2465', 'entropy': '0.1322', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.55', 'epoch': '1.968'} +{'loss': '0.02874', 'grad_norm': '6.094', 'learning_rate': '5e-06', 'num_tokens': '4.099e+07', 'completions/mean_length': '13.72', 'completions/min_length': '7', 'completions/max_length': '37.1', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.72', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '37.1', 'rewards/reward_function/mean': '0.0394', 'rewards/reward_function/std': '0.7124', 'reward': '0.0394', 'reward_std': '0.7124', 'frac_reward_zero_std': '0.6', 'kl': '0.2557', 'entropy': '0.1758', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.09', 'epoch': '1.984'} +{'loss': '0.03941', 'grad_norm': '9.562', 'learning_rate': '5e-06', 'num_tokens': '4.132e+07', 'completions/mean_length': '13.3', 'completions/min_length': '7', 'completions/max_length': '34.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.3', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.9', 'rewards/reward_function/mean': '0.02042', 'rewards/reward_function/std': '0.7211', 'reward': '0.02042', 'reward_std': '0.7211', 'frac_reward_zero_std': '0.5687', 'kl': '0.2635', 'entropy': '0.1389', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.96', 'epoch': '2'} +{'loss': '0.02399', 'grad_norm': '8.688', 'learning_rate': '5e-06', 'num_tokens': 
'4.163e+07', 'completions/mean_length': '13.84', 'completions/min_length': '7', 'completions/max_length': '35', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.84', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '35', 'rewards/reward_function/mean': '0.145', 'rewards/reward_function/std': '0.6499', 'reward': '0.145', 'reward_std': '0.6499', 'frac_reward_zero_std': '0.6062', 'kl': '0.2704', 'entropy': '0.1226', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.81', 'epoch': '2.016'} +{'loss': '0.02697', 'grad_norm': '6.5', 'learning_rate': '5e-06', 'num_tokens': '4.195e+07', 'completions/mean_length': '13.32', 'completions/min_length': '7', 'completions/max_length': '31', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.32', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '31', 'rewards/reward_function/mean': '0.149', 'rewards/reward_function/std': '0.6627', 'reward': '0.149', 'reward_std': '0.6627', 'frac_reward_zero_std': '0.6312', 'kl': '0.2559', 'entropy': '0.1241', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.73', 'epoch': '2.032'} +{'loss': '0.04065', 'grad_norm': '11.06', 'learning_rate': '5e-06', 'num_tokens': '4.231e+07', 'completions/mean_length': '14.01', 'completions/min_length': '7', 'completions/max_length': '34.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.01', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.9', 'rewards/reward_function/mean': '0.1465', 'rewards/reward_function/std': '0.6574', 'reward': '0.1465', 'reward_std': '0.6574', 'frac_reward_zero_std': '0.6', 'kl': '0.2812', 'entropy': '0.1285', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.82', 'epoch': '2.048'} +{'loss': '0.0344', 'grad_norm': '3.828', 'learning_rate': '5e-06', 'num_tokens': '4.261e+07', 'completions/mean_length': '13.88', 'completions/min_length': '7', 'completions/max_length': '41.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.88', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '41.9', 'rewards/reward_function/mean': '0.117', 'rewards/reward_function/std': '0.7087', 'reward': '0.117', 'reward_std': '0.7087', 'frac_reward_zero_std': '0.6813', 'kl': '0.2577', 'entropy': '0.1492', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.71', 'epoch': '2.064'} +{'loss': '0.04826', 'grad_norm': '7.406', 'learning_rate': '5e-06', 'num_tokens': '4.292e+07', 'completions/mean_length': '13.24', 'completions/min_length': '7', 'completions/max_length': '61.1', 'completions/clipped_ratio': '0.001563', 'completions/mean_terminated_length': '12.86', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '40.7', 'rewards/reward_function/mean': '0.0786', 'rewards/reward_function/std': '0.7012', 'reward': '0.0786', 'reward_std': '0.7012', 'frac_reward_zero_std': '0.675', 'kl': '0.272', 'entropy': '0.365', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': 
'0', 'clip_ratio/region_mean': '0', 'step_time': '23', 'epoch': '2.08'}
+ [progress bar: 1301/1875 → 1400/1875, elapsed 8:35:37 → 9:10:02, ≈21 s/it] {'loss': '0.01317', 'grad_norm': '1.539', 'learning_rate': '5e-06', 'num_tokens': '4.325e+07', 'completions/mean_length': '13.62', 'completions/min_length': '7', 'completions/max_length': '29.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.62', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '29.8', 'rewards/reward_function/mean': '0.2533', 'rewards/reward_function/std': '0.6405', 'reward': '0.2533', 'reward_std': '0.6405', 'frac_reward_zero_std': '0.75', 'kl': '0.2498', 'entropy': '0.111', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.89', 'epoch': '2.096'}
+{'loss': '0.01531', 'grad_norm': '9.125', 'learning_rate': '5e-06', 'num_tokens': '4.358e+07', 'completions/mean_length': '13.17', 'completions/min_length': '7', 'completions/max_length': '29.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.17', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '29.8',
'rewards/reward_function/mean': '0.0635', 'rewards/reward_function/std': '0.7231', 'reward': '0.0635', 'reward_std': '0.7231', 'frac_reward_zero_std': '0.6625', 'kl': '0.2922', 'entropy': '0.1164', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.94', 'epoch': '2.112'} +{'loss': '0.008194', 'grad_norm': '5.219', 'learning_rate': '5e-06', 'num_tokens': '4.39e+07', 'completions/mean_length': '13.32', 'completions/min_length': '7', 'completions/max_length': '33.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.32', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '33.8', 'rewards/reward_function/mean': '0.06899', 'rewards/reward_function/std': '0.7181', 'reward': '0.06899', 'reward_std': '0.7181', 'frac_reward_zero_std': '0.7063', 'kl': '0.2537', 'entropy': '0.1503', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.14', 'epoch': '2.128'} +{'loss': '0.0344', 'grad_norm': '6.344', 'learning_rate': '5e-06', 'num_tokens': '4.423e+07', 'completions/mean_length': '13.21', 'completions/min_length': '7', 'completions/max_length': '34.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.21', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.5', 'rewards/reward_function/mean': '0.06097', 'rewards/reward_function/std': '0.694', 'reward': '0.06097', 'reward_std': '0.694', 'frac_reward_zero_std': '0.6062', 'kl': '0.2771', 'entropy': '0.1454', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.68', 'epoch': '2.144'} +{'loss': '0.01475', 'grad_norm': '8.062', 'learning_rate': '5e-06', 'num_tokens': '4.455e+07', 'completions/mean_length': '12.87', 'completions/min_length': '7', 'completions/max_length': '30.7', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.87', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '30.7', 'rewards/reward_function/mean': '0.09332', 'rewards/reward_function/std': '0.6799', 'reward': '0.09332', 'reward_std': '0.6799', 'frac_reward_zero_std': '0.625', 'kl': '0.2978', 'entropy': '0.1194', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.55', 'epoch': '2.16'} +{'loss': '0.04098', 'grad_norm': '12.38', 'learning_rate': '5e-06', 'num_tokens': '4.487e+07', 'completions/mean_length': '13.63', 'completions/min_length': '7', 'completions/max_length': '41.1', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.63', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '41.1', 'rewards/reward_function/mean': '0.1212', 'rewards/reward_function/std': '0.6959', 'reward': '0.1212', 'reward_std': '0.6959', 'frac_reward_zero_std': '0.675', 'kl': '0.2402', 'entropy': '0.1283', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.87', 'epoch': '2.176'} +{'loss': '0.005355', 'grad_norm': '6.344', 'learning_rate': '5e-06', 'num_tokens': '4.521e+07', 'completions/mean_length': '13.5', 'completions/min_length': '7', 'completions/max_length': 
'36.7', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.5', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '36.7', 'rewards/reward_function/mean': '0.2016', 'rewards/reward_function/std': '0.6574', 'reward': '0.2016', 'reward_std': '0.6574', 'frac_reward_zero_std': '0.6687', 'kl': '0.2626', 'entropy': '0.09899', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.86', 'epoch': '2.192'} +{'loss': '0.009325', 'grad_norm': '19.75', 'learning_rate': '5e-06', 'num_tokens': '4.554e+07', 'completions/mean_length': '13.5', 'completions/min_length': '7', 'completions/max_length': '35.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.5', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '35.9', 'rewards/reward_function/mean': '0.1377', 'rewards/reward_function/std': '0.7161', 'reward': '0.1377', 'reward_std': '0.7161', 'frac_reward_zero_std': '0.6687', 'kl': '0.2585', 'entropy': '0.1257', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.83', 'epoch': '2.208'} +{'loss': '0.0167', 'grad_norm': '7.344', 'learning_rate': '5e-06', 'num_tokens': '4.585e+07', 'completions/mean_length': '12.8', 'completions/min_length': '7', 'completions/max_length': '30.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.8', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '30.8', 'rewards/reward_function/mean': '0.1214', 'rewards/reward_function/std': '0.7031', 'reward': '0.1214', 'reward_std': '0.7031', 'frac_reward_zero_std': '0.7188', 'kl': '0.2627', 'entropy': '0.1135', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '19.77', 'epoch': '2.224'} +{'loss': '0.0008669', 'grad_norm': '7.938', 'learning_rate': '5e-06', 'num_tokens': '4.621e+07', 'completions/mean_length': '13.61', 'completions/min_length': '7', 'completions/max_length': '34.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.61', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '34.8', 'rewards/reward_function/mean': '0.1223', 'rewards/reward_function/std': '0.6582', 'reward': '0.1223', 'reward_std': '0.6582', 'frac_reward_zero_std': '0.675', 'kl': '0.3147', 'entropy': '0.1098', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.11', 'epoch': '2.24'} + 75%|███████▍ | 1401/1875 [9:11:29<5:23:23, 40.94s/it] 75%|███████▍ | 1402/1875 [9:11:49<4:33:24, 34.68s/it] 75%|███████▍ | 1403/1875 [9:12:14<4:09:46, 31.75s/it] 75%|███████▍ | 1404/1875 [9:12:36<3:46:23, 28.84s/it] 75%|███████▍ | 1405/1875 [9:13:00<3:33:31, 27.26s/it] 75%|███████▍ | 1406/1875 [9:13:23<3:23:54, 26.09s/it] 75%|███████▌ | 1407/1875 [9:13:41<3:03:46, 23.56s/it] 75%|███████▌ | 1408/1875 [9:14:01<2:55:05, 22.50s/it] 75%|███████▌ | 1409/1875 [9:14:21<2:49:30, 21.83s/it] 75%|███████▌ | 1410/1875 [9:14:40<2:43:36, 21.11s/it] 75%|███████▌ | 1410/1875 [9:14:40<2:43:36, 21.11s/it] 75%|███████▌ | 1411/1875 [9:14:58<2:36:02, 20.18s/it] 75%|███████▌ | 1412/1875 [9:15:20<2:38:32, 20.54s/it] 75%|███████▌ | 1413/1875 [9:15:38<2:33:06, 19.88s/it] 
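The per-step records in this log are flat Python dict literals whose values are all numeric strings, so they can be recovered from the raw log file with `ast.literal_eval`. A minimal parsing sketch, assuming the log was saved as `train_grpo.log` (the filename and the diff-style `+` handling are assumptions, not part of the training code):

```python
import ast

records = []
with open("train_grpo.log") as f:  # hypothetical path to the saved log
    for line in f:
        line = line.strip().lstrip("+")  # tolerate diff-style '+' prefixes
        # Metric records are dict literals starting with {'loss'; the endswith
        # check skips any record truncated by the log cutting off mid-line.
        if line.startswith("{'loss'") and line.endswith("}"):
            raw = ast.literal_eval(line)  # dict of str -> str
            records.append({k: float(v) for k, v in raw.items()})

for r in records:
    print(f"epoch {r['epoch']:.3f}  reward {r['reward']:+.4f}  kl {r['kl']:.3f}")
```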
+[progress: 1414/1875 → 1500/1875 | elapsed 9:15:58 → 9:47:07 | ~19–26 s/it]
+epoch 2.256 | loss 0.01309 | grad_norm 5.906 | tokens 4.653e7 | len 13.09 (7–38.5) | reward -0.002791 ± 0.7196 | frac_zero_std 0.6062 | kl 0.326 | entropy 0.1323 | 21.35 s
+epoch 2.272 | loss 0.02266 | grad_norm 9.438 | tokens 4.686e7 | len 12.76 (7–39) | reward 0.001863 ± 0.74 | frac_zero_std 0.6562 | kl 0.3052 | entropy 0.1481 | 20.98 s
+epoch 2.288 | loss 0.02141 | grad_norm 17.62 | tokens 4.721e7 | len 14.27 (7–40.8) | reward 0.01933 ± 0.7125 | frac_zero_std 0.6375 | kl 0.3204 | entropy 0.1545 | 22.61 s
+epoch 2.304 | loss 0.01083 | grad_norm 8.25 | tokens 4.756e7 | len 13.56 (7–34.9) | reward 0.08019 ± 0.6901 | frac_zero_std 0.6438 | kl 0.2528 | entropy 0.138 | 21.8 s
+epoch 2.32 | loss 0.02766 | grad_norm 5.75 | tokens 4.789e7 | len 13.97 (7–43.2) | reward 0.05244 ± 0.654 | frac_zero_std 0.6687 | kl 0.2631 | entropy 0.1535 | 21.78 s
+epoch 2.336 | loss 0.02207 | grad_norm 5.375 | tokens 4.821e7 | len 12.77 (7–34.7) | reward -0.07867 ± 0.745 | frac_zero_std 0.7188 | kl 0.3172 | entropy 0.1453 | 21.32 s
+epoch 2.352 | loss 0.01142 | grad_norm 6.312 | tokens 4.855e7 | len 14.06 (7–34.1) | reward 0.1367 ± 0.696 | frac_zero_std 0.6687 | kl 0.238 | entropy 0.1398 | 21.69 s
+epoch 2.368 | loss 0.03659 | grad_norm 18.5 | tokens 4.886e7 | len 12.64 (7–49.9) | reward 0.000339 ± 0.665 | frac_zero_std 0.6625 | kl 0.3018 | entropy 0.3134 | 22.77 s
+epoch 2.384 | loss 0.03261 | grad_norm 8.562 | tokens 4.916e7 | len 13.5 (7.1–43.8) | reward 0.1307 ± 0.6969 | frac_zero_std 0.6125 | kl 0.2674 | entropy 0.1479 | 21.14 s
+epoch 2.4 | loss 0.01862 | grad_norm 3.125 | tokens 4.948e7 | len 13.94 (7–33.3) | reward 0.1853 ± 0.7265 | frac_zero_std 0.6562 | kl 0.3222 | entropy 0.1258 | 19.75 s
+[progress: 1501/1875 → 1527/1875 | elapsed 9:48:30 → 9:57:30]
+[progress: 1528/1875 → 1600/1875 | elapsed 9:57:49 → 10:23:58 | ~19–24 s/it]
+epoch 2.416 | loss 0.02142 | grad_norm 9.125 | tokens 4.981e7 | len 12.73 (7–26.4) | reward 0.1588 ± 0.6877 | frac_zero_std 0.7125 | kl 0.2757 | entropy 0.09753 | 20.68 s
+epoch 2.432 | loss 0.01508 | grad_norm 5.094 | tokens 5.013e7 | len 13.03 (7–31.4) | reward 0.05061 ± 0.7282 | frac_zero_std 0.7063 | kl 0.2969 | entropy 0.1369 | 20.76 s
+epoch 2.448 | loss 0.01224 | grad_norm 5.219 | tokens 5.047e7 | len 14.53 (7–41.6) | reward 0.06232 ± 0.7109 | frac_zero_std 0.6375 | kl 0.2414 | entropy 0.168 | 21.01 s
+epoch 2.464 | loss 0.04778 | grad_norm 31.12 | tokens 5.081e7 | len 13.61 (7–42.9) | reward -0.04963 ± 0.7142 | frac_zero_std 0.625 | kl 0.3074 | entropy 0.1685 | 22.11 s
+epoch 2.48 | loss 0.01889 | grad_norm 10.88 | tokens 5.113e7 | len 14.57 (7–45.6) | reward 0.07862 ± 0.6899 | frac_zero_std 0.6687 | kl 0.2354 | entropy 0.1813 | 21.72 s
+epoch 2.496 | loss 0.01917 | grad_norm 14.69 | tokens 5.146e7 | len 14.26 (7–37.3) | reward 0.1421 ± 0.6414 | frac_zero_std 0.6125 | kl 0.2192 | entropy 0.1491 | 21.3 s
+epoch 2.512 | loss 0.03111 | grad_norm 7.156 | tokens 5.179e7 | len 13.41 (7–32.3) | reward 0.1235 ± 0.6793 | frac_zero_std 0.7125 | kl 0.2231 | entropy 0.1278 | 20.48 s
+epoch 2.528 | loss 0.02517 | grad_norm 6.5 | tokens 5.21e7 | len 13.18 (7–44) | reward 0.003831 ± 0.6929 | frac_zero_std 0.5938 | kl 0.3043 | entropy 0.1689 | 21.46 s
+epoch 2.544 | loss 0.02492 | grad_norm 5.062 | tokens 5.244e7 | len 13.94 (7–40.7) | reward 0.07226 ± 0.6784 | frac_zero_std 0.6125 | kl 0.2664 | entropy 0.1399 | 22.95 s
+epoch 2.56 | loss 0.04155 | grad_norm 17.12 | tokens 5.277e7 | len 12.71 (7–34.9) | reward 0.1224 ± 0.7034 | frac_zero_std 0.7125 | kl 0.3138 | entropy 0.1366 | 21.31 s
+[progress: 1601/1875 → 1636/1875 | elapsed 10:25:27 → 10:38:14]
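`frac_reward_zero_std` sits between roughly 0.59 and 0.75 across these rows: in about two thirds of prompt groups, every sampled completion received the same reward. Under GRPO's group-normalized advantage, such groups contribute no gradient signal, which is worth keeping in mind when reading the flat reward curve. An illustrative sketch of that normalization (a sketch of the general GRPO recipe, not the trainer's internal code):

```python
import numpy as np

def group_advantages(rewards: np.ndarray, eps: float = 1e-4) -> np.ndarray:
    # Group-relative normalization in the style of GRPO (illustrative sketch).
    mean, std = rewards.mean(), rewards.std()
    if std < eps:
        # Every completion in the group got the same reward: the group is
        # uninformative and yields zero advantage for all of its samples.
        return np.zeros_like(rewards)
    return (rewards - mean) / (std + eps)

mixed = np.array([1.0, 1.0, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1])  # informative group of 8
flat = np.array([1.0] * 8)                                        # zero-std group
print(group_advantages(mixed))  # non-zero advantages -> gradient signal
print(group_advantages(flat))   # all zeros -> no update from this prompt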
+[progress: 1637/1875 → 1700/1875 | elapsed 10:38:32 → 11:00:40 | ~19–24 s/it]
+epoch 2.576 | loss 0.0282 | grad_norm 6.906 | tokens 5.311e7 | len 13.45 (7–37) | reward 0.1105 ± 0.7144 | frac_zero_std 0.6875 | kl 0.2986 | entropy 0.1074 | 22.85 s
+epoch 2.592 | loss 0.01987 | grad_norm 7.125 | tokens 5.347e7 | len 13.31 (7–29.5) | reward 0.07844 ± 0.7042 | frac_zero_std 0.6875 | kl 0.2848 | entropy 0.1307 | 21.75 s
+epoch 2.608 | loss 0.02329 | grad_norm 3.906 | tokens 5.381e7 | len 14.04 (7–38) | reward 0.1738 ± 0.6654 | frac_zero_std 0.7188 | kl 0.2736 | entropy 0.199 | 22.32 s
+epoch 2.624 | loss 0.01978 | grad_norm 4.312 | tokens 5.411e7 | len 12.25 (7–26.2) | reward 0.04388 ± 0.7204 | frac_zero_std 0.6312 | kl 0.308 | entropy 0.1204 | 19.8 s
+epoch 2.64 | loss 0.01952 | grad_norm 6.156 | tokens 5.444e7 | len 13.17 (7–33.8) | reward 0.1095 ± 0.652 | frac_zero_std 0.6438 | kl 0.2874 | entropy 0.1161 | 20.48 s
+epoch 2.656 | loss 0.01767 | grad_norm 3.031 | tokens 5.478e7 | len 13.54 (7–31.5) | reward 0.1046 ± 0.6983 | frac_zero_std 0.6687 | kl 0.2821 | entropy 0.1175 | 21.31 s
+epoch 2.672 | loss 0.008238 | grad_norm 4.125 | tokens 5.509e7 | len 12.68 (7–35.9) | reward 0.09318 ± 0.6847 | frac_zero_std 0.6813 | kl 0.309 | entropy 0.1406 | 20.82 s
+epoch 2.688 | loss 0.0224 | grad_norm 7.188 | tokens 5.542e7 | len 12.99 (7–36.5) | reward 0.0522 ± 0.6591 | frac_zero_std 0.6562 | kl 0.2935 | entropy 0.135 | 21.12 s
+epoch 2.704 | loss 0.02342 | grad_norm 4.719 | tokens 5.576e7 | len 13.21 (7–34) | reward 0.07525 ± 0.6955 | frac_zero_std 0.6312 | kl 0.2759 | entropy 0.1333 | 20.87 s
+epoch 2.72 | loss 0.01017 | grad_norm 7.5 | tokens 5.608e7 | len 13.08 (7–32.5) | reward 0.122 ± 0.6492 | frac_zero_std 0.6875 | kl 0.2605 | entropy 0.1262 | 21.51 s
+[progress: 1701/1875 → 1746/1875 | elapsed 11:02:05 → 11:18:12]
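Per-row `reward` swings between about -0.08 and +0.25 here, so whether it is actually rising is easier to judge with a moving average over the `records` list from the parsing sketch above. Another small sketch under the same assumptions:

```python
import numpy as np

# `records` as produced by the parsing sketch earlier in this log
rewards = np.array([r["reward"] for r in records])
window = 10  # one window per logging burst of 10 records
smoothed = np.convolve(rewards, np.ones(window) / window, mode="valid")
for r, value in zip(records[window - 1:], smoothed):
    print(f"epoch {r['epoch']:.3f}  reward(ma{window}) {value:+.4f}")
```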
93%|█████████▎| 1747/1875 [11:18:32<44:47, 20.99s/it] 93%|█████████▎| 1748/1875 [11:18:52<44:14, 20.90s/it] 93%|█████████▎| 1749/1875 [11:19:18<46:48, 22.29s/it] 93%|█████████▎| 1750/1875 [11:19:41<46:55, 22.53s/it] 93%|█████████▎| 1750/1875 [11:19:41<46:55, 22.53s/it] 93%|█████████▎| 1751/1875 [11:19:58<42:49, 20.72s/it] 93%|█████████▎| 1752/1875 [11:20:21<44:22, 21.64s/it] 93%|█████████▎| 1753/1875 [11:20:42<43:27, 21.37s/it] 94%|█████████▎| 1754/1875 [11:21:04<43:12, 21.42s/it] 94%|█████████▎| 1755/1875 [11:21:25<42:58, 21.48s/it] 94%|█████████▎| 1756/1875 [11:21:44<41:03, 20.70s/it] 94%|█████████▎| 1757/1875 [11:22:04<40:18, 20.50s/it] 94%|█████████▍| 1758/1875 [11:22:23<39:07, 20.06s/it] 94%|█████████▍| 1759/1875 [11:22:43<38:36, 19.97s/it] 94%|█████████▍| 1760/1875 [11:23:06<39:55, 20.83s/it] 94%|█████████▍| 1760/1875 [11:23:06<39:55, 20.83s/it] 94%|█████████▍| 1761/1875 [11:23:29<40:54, 21.53s/it] 94%|█████████▍| 1762/1875 [11:23:47<38:24, 20.40s/it] 94%|█████████▍| 1763/1875 [11:24:13<41:30, 22.24s/it] 94%|█████████▍| 1764/1875 [11:24:34<40:16, 21.77s/it] 94%|█████████▍| 1765/1875 [11:24:57<40:34, 22.13s/it] 94%|█████████▍| 1766/1875 [11:25:23<42:16, 23.27s/it] 94%|█████████▍| 1767/1875 [11:25:47<42:35, 23.66s/it] 94%|█████████▍| 1768/1875 [11:26:06<39:38, 22.23s/it] 94%|█████████▍| 1769/1875 [11:26:28<39:03, 22.11s/it] 94%|█████████▍| 1770/1875 [11:26:45<35:57, 20.55s/it] 94%|█████████▍| 1770/1875 [11:26:45<35:57, 20.55s/it] 94%|█████████▍| 1771/1875 [11:27:07<36:37, 21.13s/it] 95%|█████████▍| 1772/1875 [11:27:28<36:03, 21.00s/it] 95%|█████████▍| 1773/1875 [11:27:51<36:27, 21.44s/it] 95%|█████████▍| 1774/1875 [11:28:12<36:04, 21.43s/it] 95%|█████████▍| 1775/1875 [11:28:30<33:58, 20.38s/it] 95%|█████████▍| 1776/1875 [11:28:54<35:11, 21.33s/it] 95%|█████████▍| 1777/1875 [11:29:17<35:57, 22.01s/it] 95%|█████████▍| 1778/1875 [11:29:41<36:22, 22.50s/it] 95%|█████████▍| 1779/1875 [11:30:02<35:35, 22.24s/it] 95%|█████████▍| 1780/1875 [11:30:22<33:56, 21.44s/it] 95%|█████████▍| 1780/1875 [11:30:22<33:56, 21.44s/it] 95%|█████████▍| 1781/1875 [11:30:46<34:52, 22.26s/it] 95%|█████████▌| 1782/1875 [11:31:11<35:49, 23.12s/it] 95%|█████████▌| 1783/1875 [11:31:34<35:21, 23.06s/it] 95%|█████████▌| 1784/1875 [11:31:59<35:50, 23.63s/it] 95%|█████████▌| 1785/1875 [11:32:19<33:53, 22.59s/it] 95%|█████████▌| 1786/1875 [11:32:41<33:02, 22.27s/it] 95%|█████████▌| 1787/1875 [11:33:02<32:12, 21.96s/it] 95%|█████████▌| 1788/1875 [11:33:21<30:26, 20.99s/it] 95%|█████████▌| 1789/1875 [11:33:40<29:21, 20.49s/it] 95%|█████████▌| 1790/1875 [11:34:05<31:01, 21.90s/it] 95%|█████████▌| 1790/1875 [11:34:05<31:01, 21.90s/it] 96%|█████████▌| 1791/1875 [11:34:26<30:12, 21.57s/it] 96%|█████████▌| 1792/1875 [11:34:47<29:32, 21.36s/it] 96%|█████████▌| 1793/1875 [11:35:11<30:20, 22.20s/it] 96%|█████████▌| 1794/1875 [11:35:36<31:09, 23.08s/it] 96%|█████████▌| 1795/1875 [11:35:57<29:46, 22.33s/it] 96%|█████████▌| 1796/1875 [11:36:19<29:18, 22.26s/it] 96%|█████████▌| 1797/1875 [11:36:39<27:54, 21.47s/it] 96%|█████████▌| 1798/1875 [11:36:58<26:40, 20.79s/it] 96%|█████████▌| 1799/1875 [11:37:14<24:45, 19.55s/it] 96%|█████████▌| 1800/1875 [11:37:36<25:04, 20.06s/it] 96%|█████████▌| 1800/1875 [11:37:36<25:04, 20.06s/it]{'loss': '0.02211', 'grad_norm': '11.19', 'learning_rate': '5e-06', 'num_tokens': '5.639e+07', 'completions/mean_length': '12.74', 'completions/min_length': '7', 'completions/max_length': '35.9', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.74', 'completions/min_terminated_length': 
'7', 'completions/max_terminated_length': '35.9', 'rewards/reward_function/mean': '0.0001306', 'rewards/reward_function/std': '0.7266', 'reward': '0.0001306', 'reward_std': '0.7266', 'frac_reward_zero_std': '0.6125', 'kl': '0.3043', 'entropy': '0.1455', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.74', 'epoch': '2.736'} +{'loss': '0.009366', 'grad_norm': '4.188', 'learning_rate': '5e-06', 'num_tokens': '5.674e+07', 'completions/mean_length': '13.39', 'completions/min_length': '7', 'completions/max_length': '31.1', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.39', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '31.1', 'rewards/reward_function/mean': '0.2033', 'rewards/reward_function/std': '0.6653', 'reward': '0.2033', 'reward_std': '0.6653', 'frac_reward_zero_std': '0.725', 'kl': '0.2526', 'entropy': '0.1014', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.45', 'epoch': '2.752'} +{'loss': '0.01678', 'grad_norm': '5.5', 'learning_rate': '5e-06', 'num_tokens': '5.706e+07', 'completions/mean_length': '13.94', 'completions/min_length': '7', 'completions/max_length': '36.5', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.94', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '36.5', 'rewards/reward_function/mean': '0.07316', 'rewards/reward_function/std': '0.7211', 'reward': '0.07316', 'reward_std': '0.7211', 'frac_reward_zero_std': '0.6625', 'kl': '0.2483', 'entropy': '0.1911', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.11', 'epoch': '2.768'} +{'loss': '0.02595', 'grad_norm': '8.438', 'learning_rate': '5e-06', 'num_tokens': '5.74e+07', 'completions/mean_length': '14.16', 'completions/min_length': '7', 'completions/max_length': '42.7', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '14.16', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '42.7', 'rewards/reward_function/mean': '-0.01021', 'rewards/reward_function/std': '0.6986', 'reward': '-0.01021', 'reward_std': '0.6986', 'frac_reward_zero_std': '0.65', 'kl': '0.2393', 'entropy': '0.1659', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.12', 'epoch': '2.784'} +{'loss': '0.04412', 'grad_norm': '7.25', 'learning_rate': '5e-06', 'num_tokens': '5.772e+07', 'completions/mean_length': '13.08', 'completions/min_length': '7', 'completions/max_length': '43.6', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.08', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '43.6', 'rewards/reward_function/mean': '-0.01071', 'rewards/reward_function/std': '0.7197', 'reward': '-0.01071', 'reward_std': '0.7197', 'frac_reward_zero_std': '0.6438', 'kl': '0.2781', 'entropy': '0.1791', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.93', 'epoch': '2.8'} +{'loss': '0.01434', 'grad_norm': '5.219', 'learning_rate': '5e-06', 'num_tokens': '5.805e+07', 'completions/mean_length': '12.94', 
'completions/min_length': '7', 'completions/max_length': '27.6', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.94', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '27.6', 'rewards/reward_function/mean': '0.1241', 'rewards/reward_function/std': '0.6983', 'reward': '0.1241', 'reward_std': '0.6983', 'frac_reward_zero_std': '0.7188', 'kl': '0.2688', 'entropy': '0.1184', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.39', 'epoch': '2.816'} +{'loss': '0.01871', 'grad_norm': '9', 'learning_rate': '5e-06', 'num_tokens': '5.838e+07', 'completions/mean_length': '13.27', 'completions/min_length': '6.8', 'completions/max_length': '33.6', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.27', 'completions/min_terminated_length': '6.8', 'completions/max_terminated_length': '33.6', 'rewards/reward_function/mean': '0.07474', 'rewards/reward_function/std': '0.7445', 'reward': '0.07474', 'reward_std': '0.7445', 'frac_reward_zero_std': '0.6687', 'kl': '0.2475', 'entropy': '0.1231', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.84', 'epoch': '2.832'} +{'loss': '0.01506', 'grad_norm': '10.69', 'learning_rate': '5e-06', 'num_tokens': '5.873e+07', 'completions/mean_length': '13.45', 'completions/min_length': '7', 'completions/max_length': '26.6', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.45', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '26.6', 'rewards/reward_function/mean': '0.1487', 'rewards/reward_function/std': '0.6811', 'reward': '0.1487', 'reward_std': '0.6811', 'frac_reward_zero_std': '0.725', 'kl': '0.2482', 'entropy': '0.09926', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.61', 'epoch': '2.848'} +{'loss': '0.01854', 'grad_norm': '7.781', 'learning_rate': '5e-06', 'num_tokens': '5.908e+07', 'completions/mean_length': '13.43', 'completions/min_length': '7', 'completions/max_length': '30', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.43', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '30', 'rewards/reward_function/mean': '0.08005', 'rewards/reward_function/std': '0.7262', 'reward': '0.08005', 'reward_std': '0.7262', 'frac_reward_zero_std': '0.675', 'kl': '0.2487', 'entropy': '0.1304', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.24', 'epoch': '2.864'} +{'loss': '0.01233', 'grad_norm': '6.531', 'learning_rate': '5e-06', 'num_tokens': '5.941e+07', 'completions/mean_length': '13.07', 'completions/min_length': '7', 'completions/max_length': '29.2', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.07', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '29.2', 'rewards/reward_function/mean': '0.1002', 'rewards/reward_function/std': '0.6866', 'reward': '0.1002', 'reward_std': '0.6866', 'frac_reward_zero_std': '0.7063', 'kl': '0.2959', 'entropy': '0.1216', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': 
'0', 'step_time': '20.95', 'epoch': '2.88'}
+ 96%|█████████▌| 1801/1875 [11:38:58<47:53, 38.83s/it] … 100%|██████████| 1875/1875 [12:06:13<00:00, 20.44s/it]
+{'loss': '0.02795', 'grad_norm': '8.125', 'learning_rate': '5e-06', 'num_tokens': '5.973e+07', 'completions/mean_length': '12.86', 'completions/min_length': '7', 'completions/max_length': '33.3', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '12.86', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '33.3', 'rewards/reward_function/mean': '0.04785', 'rewards/reward_function/std': '0.6747', 'reward': '0.04785', 'reward_std': '0.6747', 'frac_reward_zero_std': '0.7', 'kl': '0.3015', 'entropy': '0.1293', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.39', 'epoch': '2.896'}
+{'loss': '0.06095', 'grad_norm': '4.156', 'learning_rate': '5e-06', 'num_tokens': '6.005e+07', 'completions/mean_length': '13.86', 'completions/min_length': '7', 'completions/max_length': '62.8', 'completions/clipped_ratio': '0.001563', 'completions/mean_terminated_length': '13.49', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '40.7', 'rewards/reward_function/mean': '0.135', 'rewards/reward_function/std': '0.6865', 'reward': '0.135', 'reward_std': '0.6865', 'frac_reward_zero_std': '0.675', 'kl': '0.2591', 'entropy': '0.3537', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '23.82', 'epoch': '2.912'}
+{'loss': '0.04141', 'grad_norm': '11.38', 'learning_rate': '5e-06', 'num_tokens': '6.038e+07', 'completions/mean_length': '13.85', 'completions/min_length': '7', 'completions/max_length': '37.7', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.85', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '37.7', 'rewards/reward_function/mean': '0.09052', 'rewards/reward_function/std': '0.7027', 'reward': '0.09052', 'reward_std': '0.7027', 'frac_reward_zero_std': '0.6375', 'kl': '0.252', 'entropy': '0.1284', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '21.32', 'epoch': '2.928'}
+{'loss': '0.01553', 'grad_norm': '11.12', 'learning_rate': '5e-06', 'num_tokens': '6.068e+07', 'completions/mean_length': '13.55', 'completions/min_length': '7', 'completions/max_length': '61.7', 'completions/clipped_ratio': '0.001563', 'completions/mean_terminated_length': '13.17', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '39.3',
'rewards/reward_function/mean': '0.09312', 'rewards/reward_function/std': '0.7026', 'reward': '0.09312', 'reward_std': '0.7026', 'frac_reward_zero_std': '0.7312', 'kl': '0.296', 'entropy': '0.401', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.5', 'epoch': '2.944'} +{'loss': '0.02066', 'grad_norm': '6.188', 'learning_rate': '5e-06', 'num_tokens': '6.101e+07', 'completions/mean_length': '13.95', 'completions/min_length': '7', 'completions/max_length': '45.6', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.95', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '45.6', 'rewards/reward_function/mean': '0.1069', 'rewards/reward_function/std': '0.6999', 'reward': '0.1069', 'reward_std': '0.6999', 'frac_reward_zero_std': '0.6813', 'kl': '0.2669', 'entropy': '0.2489', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '22.02', 'epoch': '2.96'} +{'loss': '0.0361', 'grad_norm': '9.125', 'learning_rate': '5e-06', 'num_tokens': '6.134e+07', 'completions/mean_length': '13.27', 'completions/min_length': '7', 'completions/max_length': '31', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.27', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '31', 'rewards/reward_function/mean': '0.05288', 'rewards/reward_function/std': '0.7062', 'reward': '0.05288', 'reward_std': '0.7062', 'frac_reward_zero_std': '0.6625', 'kl': '0.265', 'entropy': '0.1328', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.95', 'epoch': '2.976'} +{'loss': '0.0004211', 'grad_norm': '5.531', 'learning_rate': '5e-06', 'num_tokens': '6.17e+07', 'completions/mean_length': '13.49', 'completions/min_length': '7', 'completions/max_length': '38.4', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '13.49', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '38.4', 'rewards/reward_function/mean': '0.1066', 'rewards/reward_function/std': '0.7293', 'reward': '0.1066', 'reward_std': '0.7293', 'frac_reward_zero_std': '0.7375', 'kl': '0.2522', 'entropy': '0.1934', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '23.26', 'epoch': '2.992'} +{'train_runtime': '4.364e+04', 'train_samples_per_second': '0.688', 'train_steps_per_second': '0.043', 'train_loss': '0.02553', 'num_tokens': '6.185e+07', 'completions/mean_length': '11.94', 'completions/min_length': '7', 'completions/max_length': '31.8', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '11.94', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '31.8', 'rewards/reward_function/mean': '0.01653', 'rewards/reward_function/std': '0.723', 'reward': '0.01653', 'reward_std': '0.723', 'frac_reward_zero_std': '0.7375', 'kl': '0.2654', 'entropy': '0.1318', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '20.51', 'epoch': '3'} diff --git a/ICL/RL/plan.md b/ICL/RL/plan.md new file mode 100644 index 
0000000000000000000000000000000000000000..a078dcb7cd4a3ddd5abc8701e950ec2cbdafa415
--- /dev/null
+++ b/ICL/RL/plan.md
@@ -0,0 +1,167 @@
+# GRPO Training Plan: Model-Driven Iterative Retrieval with M3IT and SigLIP
+
+> Note: the literal `<RET>`/`<ANS>` marker strings below are reconstructed placeholders; the original token literals were stripped when this file was exported.
+
+## Part 1: Data Strategy (The Data Curriculum)
+
+To teach the model when to retrieve and when to hold back ("know when to fold, when to hold"), we must build a **1:1 adversarial data pool**.
+
+### 1. Data source: the M3IT dataset
+Split the data into two classes, then shuffle them together for RL training.
+
+* **Positive Set (retrieval required): 50% of the mix**
+    * **Task definition:** the question cannot be answered from the pixels alone; it needs external knowledge or similar cases.
+    * **Selected subsets:**
+        * **OK-VQA / A-OKVQA:** knowledge-intensive QA.
+        * **ScienceQA:** scientific common-sense reasoning.
+        * **ViQuAE (if available):** entity-knowledge retrieval.
+    * **Expected behavior:** the model should emit `<RET>`, and the more accurate its description, the better.
+
+* **Negative Set (retrieval forbidden or minimal): 50% of the mix**
+    * **Task definition:** the answer is right there in the image; retrieval is a pure waste of time and may even inject noise.
+    * **Selected subsets:**
+        * **TextVQA / OCR-VQA:** reading text in images. Retrieving similar images usually causes misreads.
+        * **VQA-v2:** basic visual recognition (color, count, position).
+        * **CLEVR:** pure logical reasoning.
+    * **Expected behavior:** the model should go straight to `<ANS>`, or issue `<ANS>` immediately after one retrieval attempt proves useless.
+
+---
+
+## Part 2: The Interaction Environment (The Rollout Pipeline)
+
+This is the **dynamic process** that happens during GRPO sampling. Every rollout is an independent exploration.
+
+### Core logic: a while loop with dynamic Top-1 injection
+
+1. **Initial state:** the input is `[User Image] + [Question]`.
+2. **Model generation (Action):**
+    * **Path A (direct answer):** the model outputs `<ANS> answer` → **done**.
+    * **Path B (retrieval):** the model outputs `<RET> description text` → **pause**.
+3. **Environment feedback:**
+    * Extract the `description text`.
+    * Use SigLIP to score its similarity against the **candidate pool**.
+    * Take **only the Top-1** most similar image together with its caption/answer.
+    * **History concatenation:** append the model's `<RET> ...` turn and a `System: Retrieved Image...` message to the dialogue history.
+4. **Loop check:**
+    * Feed the updated history back to the model for the next turn.
+    * **Max turns:** set to 3. If the model still has not emitted `<ANS>` after 3 turns, force-truncate and apply a heavy penalty.
+
+### Candidate pool construction
+
+**✅ Final choice: sub-dataset pools (one pool per task)**
+
+* **Scope:** when training on an ok-vqa sample, the retrieval pool contains only ok-vqa images.
+* **Why:** a global M3IT pool is too large (millions of images), so retrieval is slow and noisy; per-task pools are both fast and precise.
+* **Self-exclusion (mandatory):** the query image itself must be excluded at retrieval time (see the sketch below).
+    * Without exclusion, the model only has to learn to describe the original image: its SigLIP similarity with itself is always 1.0.
+    * The model would learn to "copy the original image" instead of "finding similar cases".
+    * **Implementation:** if the Top-1 image ID equals the query image ID, fall back to Top-2.
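+
+A minimal, self-contained sketch of this loop, added for illustration. `generate_turn` and `retrieve_top1` are hypothetical stand-ins for the real model call and the SigLIP search (which applies self-exclusion internally), and `<RET>`/`<ANS>` are the reconstructed markers:
+
+```python
+RET_TOKEN, ANS_TOKEN = "<RET>", "<ANS>"  # reconstructed marker strings (assumption)
+
+def rollout(generate_turn, retrieve_top1, question, max_turns=3):
+    history = [("user", question)]
+    for turn in range(1, max_turns + 1):
+        out = generate_turn(history)
+        if ANS_TOKEN in out:                        # Path A: direct answer, stop
+            return out.split(ANS_TOKEN, 1)[1].strip(), turn
+        query = out.split(RET_TOKEN, 1)[1].strip()  # Path B: retrieval description
+        history.append(("assistant", out))
+        # Dynamic Top-1 injection: only the single best match enters the history.
+        history.append(("system", f"Retrieved Image: {retrieve_top1(query)}"))
+    return None, max_turns                          # no answer within budget: timeout
+
+# Canned two-turn demo: one retrieval, then an answer.
+turns = iter([f"{RET_TOKEN} a red double-decker bus", f"{ANS_TOKEN} london"])
+print(rollout(lambda h: next(turns), lambda q: "bus_0017.jpg", "Where is this?"))
+# -> ('london', 2)
+```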
+
+---
+
+## Part 3: Reward Design (The Robust Reward) - **revised after data analysis**
+
+### ⚠️ Important revision: R_faith removed
+
+**Data-analysis results (500 samples):**
+- Image-to-caption similarity: Mean = 0.0732, Max = 0.1797
+- **Conclusion:** SigLIP's text encoder handles short answers very poorly; the scores are pure noise.
+- **R_faith has been shown to be useless and is removed entirely.**
+
+### Final reward formula
+
+$$R_{total} = R_{outcome} + \text{Gate}(IsCorrect) \times R_{rel} + R_{penalty}$$
+
+### 1. $R_{outcome}$: the hard outcome metric (The Gatekeeper)
+* **Logic:** no matter how fancy the process is, a wrong answer earns nothing.
+* **Scoring:**
+    * **Answer correct (F1/Exact Match):** **+1.0**
+    * **Answer wrong:** **-1.0**
+
+### 2. $R_{penalty}$: tiered step penalty (Efficiency)
+* **Logic:** retrieval has a cost. The more shots used, the harsher the (non-linear) deduction.
+* **Scoring:**
+    * **0-shot (direct answer):** $0.0$
+    * **1-shot:** $-0.1$
+    * **2-shot:** $-0.25$
+    * **3-shot:** $-0.5$
+    * **Reaching max_turns without `<ANS>`:** $-2.0$ (dead-loop penalty)
+* **Effect:** if the model can answer correctly with 1 shot (1.0 - 0.1 + R_rel), it will never use 3 shots (1.0 - 0.5 + R_rel).
+
+### 3. $R_{rel}$: retrieval relevance (Relevance) - **the only visual process score**
+
+* **Purpose:** ensure the retrieved image is actually similar to the original image (proof that the description works).
+* **Computation:** `scale_siglip_score(SigLIP_Score(original image, retrieved Top-1 image))`
+* **Normalization function:**
+    ```python
+    def scale_siglip_score(raw_score, min_th=0.33, max_th=0.97):
+        """
+        Thresholds from data analysis:
+        - min_th = 0.33 (P05 of same-subdir pairs)
+        - max_th = 0.97 (P95 of same-subdir pairs)
+        """
+        scaled = (raw_score - min_th) / (max_th - min_th)
+        return max(0.0, min(1.0, scaled))
+    ```
+* **Gating:** $R_{rel}$ only applies when $R_{outcome} > 0$ (the answer was correct)!
+
+### 4. Forced-truncation handling
+
+* **Trigger:** the model emits `<RET>` 3 times in a row without ever emitting `<ANS>`.
+* **Penalty:** $R_{penalty} = -2.0$
+* **Logic:** a wrong answer costs -1.0; a dead loop not only fails to answer but also burns 3x the compute, so it must be punished harder than answering wrong outright.
+* **Implementation:** truncate the trajectory during training and assign Reward = -2.0; do not force an `<ANS>` output.
+
+---
+
+## Part 4: A GRPO Training Walkthrough (Case Study) - **with the revised reward**
+
+How does the model evolve under this scheme? We simulate the competition within one batch (Group Size = 8):
+
+* **Scenario:** a medium-difficulty question (needs 1 reference image).
+
+### Sample comparison
+
+| Sample | Behavior | R_outcome | R_penalty | R_rel | Total |
+|------|------|-----------|-----------|-------|------|
+| 1. Direct answer, wrong | No retrieval, wrong answer | -1.0 | 0.0 | 0.0 | **-1.0** |
+| 2. Perfect 1-shot | Accurate description → good image → correct | +1.0 | -0.1 | +1.0 | **1.9** |
+| 3. Average 1-shot | Mediocre image retrieved → correct | +1.0 | -0.1 | +0.5 | **1.4** |
+| 4. Bad-image 1-shot | Bad image retrieved → lucky guess, correct | +1.0 | -0.1 | +0.0 | **0.9** |
+| 5. Perfect 3-shot | Retrieved 3 times → correct | +1.0 | -0.5 | +1.0 | **1.5** |
+| 6. Direct answer, correct | No retrieval, correct answer | +1.0 | 0.0 | 0.0 | **1.0** |
+| 7. Dead loop | 3 `<RET>`s, no `<ANS>` | -1.0 | -2.0 | 0.0 | **-3.0** |
+
+### The direction of GRPO's gradient update
+
+The model will discover that **sample 2 (1.9 points)** is the clear winner (the group-relative advantages are computed in the sketch after this list):
+1. It beats sample 1: **when retrieval is needed, you must retrieve.**
+2. It beats sample 5: **if one retrieval is enough, do not use three** (efficiency first).
+3. It beats sample 4: **the description must be accurate** (only a high R_rel earns the top score).
+4. It beats sample 6: **for questions that need retrieval, retrieving beats guessing.**
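+
+GRPO does not use these raw totals directly; each trajectory is scored relative to its group. Below is a self-contained calculation over the seven samples listed above (the eighth group member is omitted in the table). TRL's `GRPOTrainer` performs this normalization internally, so the snippet is purely illustrative:
+
+```python
+rewards = [-1.0, 1.9, 1.4, 0.9, 1.5, 1.0, -3.0]  # "Total" column above
+mean = sum(rewards) / len(rewards)               # ~0.386
+std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
+for i, r in enumerate(rewards, 1):
+    print(f"sample {i}: advantage {(r - mean) / std:+.2f}")
+# Sample 2 gets the largest positive advantage; samples 1 and 7 are pushed down hardest.
+# A group whose rewards are all identical (std = 0) yields zero advantage, which is
+# what the frac_reward_zero_std metric in the training log tracks.
+```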
[("completions/mean_terminated_length", "Mean"), ("completions/min_terminated_length", "Min"), + ("completions/max_terminated_length", "Max")], "Token Length"), + ("clipped_ratio.png", "Completions Clipped Ratio", [("completions/clipped_ratio", "Clipped Ratio")], "Ratio"), + ("frac_reward_zero_std.png", "Fraction Reward Zero Std", [("frac_reward_zero_std", "Frac Zero Std")], "Fraction"), + ("clip_ratio_high.png", "Clip Ratio (High)", + [("clip_ratio/high_mean", "High Mean"), ("clip_ratio/high_max", "High Max")], "Ratio"), + ("clip_ratio_low.png", "Clip Ratio (Low)", + [("clip_ratio/low_mean", "Low Mean"), ("clip_ratio/low_min", "Low Min")], "Ratio"), + ("step_time.png", "Step Time", [("step_time", "Step Time (s)")], "Seconds"), +] + +# Smoothing helper (simple moving average) +def smooth(y, window=21): + if len(y) < window: + return y + cumsum = np.cumsum(np.insert(y, 0, 0)) + return (cumsum[window:] - cumsum[:-window]) / window + +# Plot each config +for fname, title, metrics, ylabel in plot_configs: + fig, ax = plt.subplots(figsize=(12, 5)) + for key, label in metrics: + vals = [r.get(key, float('nan')) for r in records] + ax.plot(steps, vals, alpha=0.3, linewidth=0.8) + # add smoothed line + s = smooth(np.array(vals)) + offset = (len(vals) - len(s)) // 2 + ax.plot(steps[offset:offset+len(s)], s, linewidth=2, label=label + " (smoothed)") + ax.set_xlabel("Step") + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.legend() + ax.grid(True, alpha=0.3) + plt.tight_layout() + path = os.path.join(SAVE_DIR, fname) + fig.savefig(path, dpi=150) + plt.close(fig) + print(f"Saved: {path}") + +# Also make a combined overview figure +fig, axes = plt.subplots(3, 3, figsize=(20, 14)) +overview_metrics = [ + ("loss", "Loss"), + ("reward", "Reward"), + ("kl", "KL Divergence"), + ("entropy", "Entropy"), + ("grad_norm", "Grad Norm"), + ("completions/mean_length", "Completion Mean Length"), + ("frac_reward_zero_std", "Frac Reward Zero Std"), + ("completions/clipped_ratio", "Clipped Ratio"), + ("step_time", "Step Time (s)"), +] +for ax, (key, title) in zip(axes.flat, overview_metrics): + vals = [r.get(key, float('nan')) for r in records] + ax.plot(steps, vals, alpha=0.3, linewidth=0.8, color='steelblue') + s = smooth(np.array(vals)) + offset = (len(vals) - len(s)) // 2 + ax.plot(steps[offset:offset+len(s)], s, linewidth=2, color='orangered') + ax.set_title(title, fontsize=12) + ax.set_xlabel("Step", fontsize=9) + ax.grid(True, alpha=0.3) + +fig.suptitle("GRPO Training Overview", fontsize=16, fontweight='bold') +plt.tight_layout(rect=[0, 0, 1, 0.96]) +path = os.path.join(SAVE_DIR, "overview.png") +fig.savefig(path, dpi=150) +plt.close(fig) +print(f"Saved: {path}") + +print(f"\nAll plots saved to: {SAVE_DIR}") diff --git a/ICL/RL/plots/clip_ratio_high.png b/ICL/RL/plots/clip_ratio_high.png new file mode 100644 index 0000000000000000000000000000000000000000..f9d4ce1eaac66d348f7a9eb4179cfbeb53ea38cd Binary files /dev/null and b/ICL/RL/plots/clip_ratio_high.png differ diff --git a/ICL/RL/plots/clip_ratio_low.png b/ICL/RL/plots/clip_ratio_low.png new file mode 100644 index 0000000000000000000000000000000000000000..74b07493679603e9da0b80b7964d0fec8947b74c Binary files /dev/null and b/ICL/RL/plots/clip_ratio_low.png differ diff --git a/ICL/RL/plots/clipped_ratio.png b/ICL/RL/plots/clipped_ratio.png new file mode 100644 index 0000000000000000000000000000000000000000..905f6aa119e62f4b8d2c017ee20408c6e77aba7d Binary files /dev/null and b/ICL/RL/plots/clipped_ratio.png differ diff --git a/ICL/RL/plots/completion_length.png 
b/ICL/RL/plots/completion_length.png new file mode 100644 index 0000000000000000000000000000000000000000..e47b2c8a0fc71a5a5c7bb44208257fba65ef15d1 Binary files /dev/null and b/ICL/RL/plots/completion_length.png differ diff --git a/ICL/RL/plots/entropy.png b/ICL/RL/plots/entropy.png new file mode 100644 index 0000000000000000000000000000000000000000..30aba9cb0d384d1abc0ad131a5881c9f946b651e Binary files /dev/null and b/ICL/RL/plots/entropy.png differ diff --git a/ICL/RL/plots/frac_reward_zero_std.png b/ICL/RL/plots/frac_reward_zero_std.png new file mode 100644 index 0000000000000000000000000000000000000000..60fc17e29e8ea6c54c71195b89670a063ccc7a37 Binary files /dev/null and b/ICL/RL/plots/frac_reward_zero_std.png differ diff --git a/ICL/RL/plots/grad_norm.png b/ICL/RL/plots/grad_norm.png new file mode 100644 index 0000000000000000000000000000000000000000..a398b9bfe27a313de0be56cb97da173c766d43d0 Binary files /dev/null and b/ICL/RL/plots/grad_norm.png differ diff --git a/ICL/RL/plots/learning_rate.png b/ICL/RL/plots/learning_rate.png new file mode 100644 index 0000000000000000000000000000000000000000..445036d9a455864d451e41d248f652a335af18f9 Binary files /dev/null and b/ICL/RL/plots/learning_rate.png differ diff --git a/ICL/RL/quickstart.sh b/ICL/RL/quickstart.sh new file mode 100644 index 0000000000000000000000000000000000000000..d404dd245a3505f219895776331902eab80beb01 --- /dev/null +++ b/ICL/RL/quickstart.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Quick start script for GRPO training + +echo "==================================" +echo "GRPO Training Quick Start" +echo "==================================" +echo "" + +# Check if virtual environment exists +if [ ! -d "venv" ]; then + echo "Creating virtual environment..." + python3 -m venv venv +fi + +# Activate virtual environment +echo "Activating virtual environment..." +source venv/bin/activate + +# Install dependencies +echo "Installing dependencies..." +pip install -r requirements.txt + +echo "" +echo "==================================" +echo "Setup complete!" +echo "==================================" +echo "" +echo "Next steps:" +echo " 1. Test reward functions: python test_rewards.py" +echo " 2. Visualize rewards: python visualize_rewards.py" +echo " 3. Start training: python train_grpo.py" +echo "" +echo "To monitor training:" +echo " - Check wandb dashboard (if enabled)" +echo " - View logs in output/ directory" +echo "" diff --git a/ICL/RL/requirements.txt b/ICL/RL/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e555414a76052aef015369ad0b69cb7260c5f7c3 --- /dev/null +++ b/ICL/RL/requirements.txt @@ -0,0 +1,16 @@ +torch>=2.0.0 +transformers>=4.40.0 +trl>=0.12.0 +datasets>=2.14.0 +accelerate>=0.27.0 +peft>=0.10.0 +bitsandbytes>=0.43.0 +sentencepiece>=0.2.0 +protobuf>=3.20.0 +pillow>=10.0.0 +numpy>=1.24.0 +scikit-learn>=1.3.0 +tqdm>=4.65.0 +wandb>=0.15.0 +matplotlib>=3.7.0 +pyyaml>=6.0 diff --git a/ICL/RL/reward_functions.py b/ICL/RL/reward_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..6ad715c0268b0a86676d2443a47206a1e83b6a41 --- /dev/null +++ b/ICL/RL/reward_functions.py @@ -0,0 +1,135 @@ +""" +Reward functions for GRPO training. +Based on the plan: R_total = R_outcome + Gate(IsCorrect) × R_rel + R_penalty +""" +import re +from typing import List, Dict, Any +import torch + + +def scale_siglip_score(raw_score: float, min_th: float = 0.33, max_th: float = 0.97) -> float: + """ + Scale SigLIP similarity score to [0, 1] range. 
+    Based on data analysis: P05=0.33, P95=0.97 for same-subdir pairs.
+    """
+    scaled = (raw_score - min_th) / (max_th - min_th)
+    return max(0.0, min(1.0, scaled))
+
+
+def compute_f1_score(prediction: str, ground_truth: str) -> float:
+    """Compute F1 score between prediction and ground truth."""
+    pred_tokens = set(prediction.lower().split())
+    gt_tokens = set(ground_truth.lower().split())
+
+    if len(pred_tokens) == 0 or len(gt_tokens) == 0:
+        return 0.0
+
+    common = pred_tokens & gt_tokens
+    if len(common) == 0:
+        return 0.0
+
+    precision = len(common) / len(pred_tokens)
+    recall = len(common) / len(gt_tokens)
+    f1 = 2 * precision * recall / (precision + recall)
+    return f1
+
+
+# NOTE: the literal action-marker strings were lost when this file was exported;
+# "<RET>" / "<ANS>" are reconstructed placeholders for the retrieval and answer
+# markers referenced throughout the plan.
+RET_TOKEN = "<RET>"
+ANS_TOKEN = "<ANS>"
+
+
+def extract_answer(text: str) -> str:
+    """Extract the answer from model output after the ANS_TOKEN marker."""
+    if ANS_TOKEN in text:
+        return text.split(ANS_TOKEN)[-1].strip()
+    return text.strip()
+
+
+def count_retrieval_steps(trajectory: str) -> int:
+    """Count the number of RET_TOKEN markers in the trajectory."""
+    return trajectory.count(RET_TOKEN)
+
+
+def compute_reward(
+    trajectory: str,
+    ground_truth: str,
+    siglip_scores: List[float],
+    max_turns: int = 3,
+    timeout: bool = False
+) -> Dict[str, float]:
+    """
+    Compute total reward for a trajectory.
+
+    Args:
+        trajectory: Full model output with RET_TOKEN and ANS_TOKEN markers
+        ground_truth: Ground truth answer
+        siglip_scores: List of SigLIP similarity scores for each retrieval
+        max_turns: Maximum allowed retrieval turns
+        timeout: Whether trajectory hit max turns without ANS_TOKEN
+
+    Returns:
+        Dictionary with reward components and total
+    """
+    # Extract final answer
+    prediction = extract_answer(trajectory)
+
+    # Count retrieval steps
+    num_steps = count_retrieval_steps(trajectory)
+
+    # R_outcome: continuous outcome reward (dense signal)
+    # Map f1 ∈ [0, 1] to r_outcome ∈ [-1, 1]
+    f1 = compute_f1_score(prediction, ground_truth)
+    r_outcome = 2.0 * f1 - 1.0
+
+    # Keep a softer correctness flag for logging only (no longer gates reward)
+    is_correct = f1 > 0.3
+
+    # R_penalty: Step-based penalty (non-linear)
+    if timeout:
+        # Death penalty for infinite loop (softened to encourage exploration early)
+        r_penalty = -1.5
+    else:
+        # Lighter penalty encourages exploration/retrieval early
+        penalty_map = {0: 0.0, 1: -0.05, 2: -0.15, 3: -0.3}
+        r_penalty = penalty_map.get(num_steps, -0.3)
+
+    # R_rel: Retrieval relevance (decoupled; always provides signal if retrieval happened)
+    r_rel = 0.0
+    if len(siglip_scores) > 0:
+        # Average scaled SigLIP scores across all retrievals
+        scaled_scores = [scale_siglip_score(score) for score in siglip_scores]
+        avg_rel = sum(scaled_scores) / len(scaled_scores)
+        # Soft-gate by f1 but keep a minimum signal; downweight to avoid dominating r_outcome
+        r_rel = avg_rel * max(0.3, f1) * 0.5
+
+    # Total reward
+    r_total = r_outcome + r_rel + r_penalty
+
+    return {
+        "r_outcome": r_outcome,
+        "r_penalty": r_penalty,
+        "r_rel": r_rel,
+        "r_total": r_total,
+        "f1_score": f1,
+        "num_steps": num_steps,
+        "is_correct": is_correct,
+        "timeout": timeout
+    }
+
+
+def batch_compute_rewards(
+    trajectories: List[str],
+    ground_truths: List[str],
+    siglip_scores_list: List[List[float]],
+    max_turns: int = 3,
+    timeouts: List[bool] = None
+) -> List[float]:
+    """
+    Compute rewards for a batch of trajectories.
+    Returns list of total rewards for GRPO.
+    """
+    if timeouts is None:
+        timeouts = [False] * len(trajectories)
+
+    rewards = []
+    for traj, gt, scores, timeout in zip(trajectories, ground_truths, siglip_scores_list, timeouts):
+        reward_dict = compute_reward(traj, gt, scores, max_turns, timeout)
+        rewards.append(reward_dict["r_total"])
+
+    return rewards
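+
+
+if __name__ == "__main__":
+    # Minimal smoke demo of the reward decomposition (added for illustration;
+    # the trajectory and SigLIP score are made up, and RET_TOKEN/ANS_TOKEN are
+    # the reconstructed marker strings defined above).
+    demo_traj = f"{RET_TOKEN} a man riding a horse {ANS_TOKEN} polo"
+    print(compute_reward(demo_traj, ground_truth="polo", siglip_scores=[0.8]))
+    # -> r_outcome=1.0 (exact match), r_penalty=-0.05 (one retrieval),
+    #    r_rel≈0.37 (scaled SigLIP, soft-gated by F1), r_total≈1.32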
+ """ + if timeouts is None: + timeouts = [False] * len(trajectories) + + rewards = [] + for traj, gt, scores, timeout in zip(trajectories, ground_truths, siglip_scores_list, timeouts): + reward_dict = compute_reward(traj, gt, scores, max_turns, timeout) + rewards.append(reward_dict["r_total"]) + + return rewards diff --git a/ICL/RL/run_grpo.sh b/ICL/RL/run_grpo.sh new file mode 100644 index 0000000000000000000000000000000000000000..d13790d5af98d44065e9f0a898a262b4a8d8feb7 --- /dev/null +++ b/ICL/RL/run_grpo.sh @@ -0,0 +1,29 @@ +#!/bin/bash +cd /workspace/xiaobin/RL + +LOG_FILE="train_grpo_$(date +%Y%m%d_%H%M%S).log" +KEY_LOG="key_metrics_$(date +%Y%m%d_%H%M%S).log" + +echo "Full log: $LOG_FILE" +echo "Key metrics log: $KEY_LOG" + +# Run training, full output to LOG_FILE +nohup python train_grpo.py > "$LOG_FILE" 2>&1 & +PID=$! +echo "Training started with PID: $PID" +echo "$PID" > train_pid.txt + +# Background process to extract key metrics every 30s +nohup bash -c " +while kill -0 $PID 2>/dev/null; do + grep -E '(loss|reward|entropy|clip_ratio|grad_norm|epoch|frac_reward)' \"$LOG_FILE\" | tail -200 > \"$KEY_LOG\" + sleep 30 +done +# Final extraction +grep -E '(loss|reward|entropy|clip_ratio|grad_norm|epoch|frac_reward)' \"$LOG_FILE\" > \"$KEY_LOG\" +echo 'Training finished.' >> \"$KEY_LOG\" +" > /dev/null 2>&1 & + +echo "Metric watcher started." +echo "To monitor: tail -f /workspace/xiaobin/RL/$LOG_FILE" +echo "Key metrics: cat /workspace/xiaobin/RL/$KEY_LOG" diff --git a/ICL/RL/siglip_analysis/score_distribution.png b/ICL/RL/siglip_analysis/score_distribution.png new file mode 100644 index 0000000000000000000000000000000000000000..30a786eccbd89fc2c8572a5c9403eadc436cd66c Binary files /dev/null and b/ICL/RL/siglip_analysis/score_distribution.png differ diff --git a/ICL/RL/test_device_config.py b/ICL/RL/test_device_config.py new file mode 100644 index 0000000000000000000000000000000000000000..2e20dccd2f527ca80233869ff424b6b810f31177 --- /dev/null +++ b/ICL/RL/test_device_config.py @@ -0,0 +1,111 @@ +""" +Test script to verify multi-GPU device configuration is correct. 
+""" +import torch +from transformers import AutoModel, AutoProcessor + +def test_device_config(): + print("="*80) + print("TESTING DEVICE CONFIGURATION") + print("="*80) + + # Check CUDA availability + print(f"\nCUDA available: {torch.cuda.is_available()}") + print(f"GPU count: {torch.cuda.device_count()}") + + if torch.cuda.device_count() > 0: + for i in range(torch.cuda.device_count()): + print(f"GPU {i}: {torch.cuda.get_device_name(i)}") + + # Load SigLIP with device_map="auto" + print("\n" + "="*80) + print("LOADING SIGLIP MODEL") + print("="*80) + + siglip_model_name = "/workspace/siglip2-so400m-patch16-naflex" + siglip_model = AutoModel.from_pretrained(siglip_model_name, device_map="auto") + siglip_processor = AutoProcessor.from_pretrained(siglip_model_name) + + # Check where model parameters are + print("\nSigLIP model parameter devices:") + for name, param in list(siglip_model.named_parameters())[:5]: + print(f" {name}: {param.device}") + + # Get the device of the first parameter + model_device = next(siglip_model.parameters()).device + print(f"\nFirst parameter device: {model_device}") + + # Test text encoding + print("\n" + "="*80) + print("TESTING TEXT ENCODING") + print("="*80) + + text_inputs = siglip_processor( + text=["a photo of a cat"], + return_tensors="pt", + padding=True + ) + + print(f"Text inputs device before moving: {text_inputs['input_ids'].device}") + + # Move to model device + text_inputs = {k: v.to(model_device) for k, v in text_inputs.items()} + print(f"Text inputs device after moving: {text_inputs['input_ids'].device}") + + with torch.no_grad(): + text_embeds = siglip_model.get_text_features(**text_inputs) + print(f"Text embeddings device: {text_embeds.device}") + print(f"Text embeddings shape: {text_embeds.shape}") + + # Test image encoding + print("\n" + "="*80) + print("TESTING IMAGE ENCODING") + print("="*80) + + from PIL import Image + import numpy as np + + # Create a dummy image + dummy_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)) + + image_inputs = siglip_processor( + images=[dummy_image], + return_tensors="pt", + padding=True + ) + + print(f"Image inputs device before moving: {image_inputs['pixel_values'].device}") + + # Move to model device + image_inputs = {k: v.to(model_device) for k, v in image_inputs.items()} + print(f"Image inputs device after moving: {image_inputs['pixel_values'].device}") + + with torch.no_grad(): + image_embeds = siglip_model.get_image_features(**image_inputs) + print(f"Image embeddings device: {image_embeds.device}") + print(f"Image embeddings shape: {image_embeds.shape}") + + # Test similarity computation + print("\n" + "="*80) + print("TESTING SIMILARITY COMPUTATION") + print("="*80) + + # Create dummy embeddings on CPU + dummy_pool_embeds = torch.randn(100, text_embeds.shape[-1]) + print(f"Pool embeddings device (CPU): {dummy_pool_embeds.device}") + + # Move to same device as text_embeds + dummy_pool_embeds = dummy_pool_embeds.to(text_embeds.device) + print(f"Pool embeddings device after moving: {dummy_pool_embeds.device}") + + # Compute similarity + similarities = (text_embeds @ dummy_pool_embeds.T).squeeze(0) + print(f"Similarities device: {similarities.device}") + print(f"Similarities shape: {similarities.shape}") + + print("\n" + "="*80) + print("ALL TESTS PASSED!") + print("="*80) + +if __name__ == "__main__": + test_device_config() diff --git a/ICL/RL/train_grpo.py b/ICL/RL/train_grpo.py new file mode 100644 index 
0000000000000000000000000000000000000000..afb50779d84385258d36ddd3b2f24ea596efbff5
--- /dev/null
+++ b/ICL/RL/train_grpo.py
@@ -0,0 +1,389 @@
+"""
+Main GRPO training script for retrieval-augmented VQA.
+Based on TRL's GRPOTrainer with custom reward functions and environment.
+"""
+import os
+os.environ["PYTHONIOENCODING"] = "utf-8"
+os.environ["LC_ALL"] = "C.UTF-8"
+import torch
+from transformers import AutoTokenizer, AutoProcessor
+from trl import GRPOTrainer, GRPOConfig
+from datasets import Dataset
+import wandb
+from typing import List, Dict, Any
+
+from reward_functions import batch_compute_rewards
+from environment import RetrievalEnvironment
+from data_utils import (
+    load_m3it_splits,
+    create_balanced_dataset,
+    build_candidate_pools,
+    format_prompt
+)
+
+# NOTE: the literal action-marker strings were lost when this file was exported;
+# "<RET>" / "<ANS>" are reconstructed placeholders for the retrieval and answer
+# markers used throughout the plan and the reward code.
+RET_TOKEN = "<RET>"
+ANS_TOKEN = "<ANS>"
+
+
+def create_reward_function(environment: RetrievalEnvironment):
+    """
+    Create a reward function closure that captures the environment.
+    This function will be passed to GRPOTrainer as reward_funcs.
+    """
+    def parse_retrieval_queries(completion: str) -> List[str]:
+        """
+        Parse retrieval queries from completion text.
+        Extracts the text between each RET_TOKEN and the next ANS_TOKEN (or end of text).
+        """
+        queries = []
+        parts = completion.split(RET_TOKEN)
+
+        for i in range(1, len(parts)):  # Skip first part (before the first RET_TOKEN)
+            query_part = parts[i]
+            # Splitting on RET_TOKEN already removed the retrieval markers, so only
+            # an ANS_TOKEN can still terminate a query here.
+            if ANS_TOKEN in query_part:
+                query = query_part.split(ANS_TOKEN)[0].strip()
+            else:
+                query = query_part.strip()
+
+            if query:
+                queries.append(query)
+
+        return queries
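+
+    # Illustrative behavior of the parser above (markers are the reconstructed
+    # RET_TOKEN/ANS_TOKEN strings):
+    #   parse_retrieval_queries("<RET> a red bus <RET> a bus at night <ANS> london")
+    #   -> ["a red bus", "a bus at night"]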
+    def reward_function(completions: List[str], **kwargs) -> List[float]:
+        """
+        Custom reward function for GRPO.
+
+        For each completion:
+        1. Parse queries from the trajectory
+        2. Execute retrieval to get SigLIP scores
+        3. Compute reward using the full reward function
+        """
+        # Extract ground truth answers from kwargs
+        # GRPO passes dataset columns through kwargs
+        answers = kwargs.get("answer", [])
+        images = kwargs.get("image", [])
+        datasets = kwargs.get("dataset", [])
+        ids = kwargs.get("id", [])
+
+        if not answers:
+            # Fallback: return zero rewards if no ground truth
+            print("WARNING: No ground truth answers found in kwargs")
+            return [0.0] * len(completions)
+
+        rewards = []
+        siglip_scores_list = []
+        timeouts = []
+
+        for completion, answer, image, dataset_name, sample_id in zip(
+            completions, answers, images, datasets, ids
+        ):
+            # Convert completion to string
+            # For conversational format, completion is a list like [{"role": "assistant", "content": "..."}]
+            if isinstance(completion, list) and len(completion) > 0:
+                if isinstance(completion[0], dict) and "content" in completion[0]:
+                    completion_text = completion[0]["content"]
+                else:
+                    completion_text = str(completion)
+            else:
+                completion_text = str(completion)
+
+            # Parse retrieval queries from completion
+            retrieval_queries = parse_retrieval_queries(completion_text)
+
+            # Execute retrieval for each query to get SigLIP scores
+            siglip_scores = []
+            if len(retrieval_queries) > 0:
+                for query_text in retrieval_queries:
+                    try:
+                        # Retrieve top-1 similar image using text query
+                        retrieved_candidate, _ = environment.retrieve_top1(
+                            query_text=query_text,
+                            dataset_name=dataset_name,
+                            exclude_id=sample_id
+                        )
+
+                        # Compute image-to-image similarity for reward
+                        img_similarity = environment.compute_image_similarity(
+                            query_image=image,
+                            retrieved_image=retrieved_candidate["image"]
+                        )
+                        siglip_scores.append(img_similarity)
+                    except Exception as e:
+                        print(f"Retrieval error for query '{query_text}': {e}")
+                        siglip_scores.append(0.0)
+
+            # Check for timeout (exceeded max turns without ANS_TOKEN)
+            num_rets = completion_text.count(RET_TOKEN)
+            has_answer = ANS_TOKEN in completion_text
+            timeout = (num_rets >= environment.max_turns and not has_answer)
+
+            siglip_scores_list.append(siglip_scores)
+            timeouts.append(timeout)
+
+        # Compute rewards using our custom function
+        # Convert completions to strings for batch_compute_rewards
+        completion_texts = []
+        for comp in completions:
+            if isinstance(comp, list) and len(comp) > 0:
+                if isinstance(comp[0], dict) and "content" in comp[0]:
+                    completion_texts.append(comp[0]["content"])
+                else:
+                    completion_texts.append(str(comp))
+            else:
+                completion_texts.append(str(comp))
+
+        rewards = batch_compute_rewards(
+            trajectories=completion_texts,
+            ground_truths=answers,
+            siglip_scores_list=siglip_scores_list,
+            max_turns=environment.max_turns,
+            timeouts=timeouts
+        )
+
+        return rewards
+
+    return reward_function
+
+
+def load_models(
+    model_path: str,
+    siglip_model_name: str = "/workspace/siglip2-so400m-patch16-naflex",
+    device: str = "cuda"
+):
+    """Load VLM and SigLIP models."""
+    print(f"Loading VLM from {model_path}...")
+
+    # Import the specific Qwen3VL generation class
+    from transformers.models.qwen3_vl import Qwen3VLForConditionalGeneration
+    from transformers import AutoModel
+
+    model = Qwen3VLForConditionalGeneration.from_pretrained(
+        model_path,
+        dtype=torch.bfloat16,
+        device_map="auto",
+        trust_remote_code=True,
+        attn_implementation="flash_attention_2"
+    )
+
+    # Use processor instead of tokenizer for vision-language models
+    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+
+    print(f"Loading SigLIP from {siglip_model_name}...")
+    siglip_model = AutoModel.from_pretrained(siglip_model_name, device_map="auto")
+    siglip_processor = AutoProcessor.from_pretrained(siglip_model_name)
+
+    return model, processor, siglip_model, siglip_processor
+
+
+def prepare_dataset_for_grpo(samples: List[Dict]) -> Dataset:
+    """
+    Convert samples to HuggingFace Dataset format for GRPO.
+    GRPO expects simple chat messages (strings), and will handle image insertion automatically.
+    """
+    dataset_dict = {
+        "prompt": [],
+        "image": [],
+        "answer": [],
+        "dataset": [],
+        "id": [],
+        "needs_retrieval": []
+    }
+
+    for sample in samples:
+        # Create simple chat messages - GRPO will handle image insertion
+        messages = [
+            {
+                "role": "system",
+                "content": (
+                    "You are a visual question answering assistant. "
+                    f"You can either answer directly using {ANS_TOKEN} or retrieve similar images using {RET_TOKEN}.\n"
+                    f"- Use {RET_TOKEN} followed by a description when you need to see similar examples.\n"
+                    f"- Use {ANS_TOKEN} followed by your answer when you're ready to answer."
+                )
+            },
+            {
+                "role": "user",
+                "content": f"Question: {sample['question']}"
+            }
+        ]
+
+        dataset_dict["prompt"].append(messages)
+        dataset_dict["image"].append(sample["image"])
+        dataset_dict["answer"].append(sample["answer"])
+        dataset_dict["dataset"].append(sample["dataset"])
+        dataset_dict["id"].append(sample["id"])
+        dataset_dict["needs_retrieval"].append(sample["needs_retrieval"])
+
+    # Create dataset without using from_dict to avoid PyArrow issues
+    from datasets import Dataset as HFDataset
+
+    # Convert to list of dicts format
+    data_list = []
+    for i in range(len(dataset_dict["prompt"])):
+        data_list.append({
+            "prompt": dataset_dict["prompt"][i],
+            "image": dataset_dict["image"][i],
+            "answer": dataset_dict["answer"][i],
+            "dataset": dataset_dict["dataset"][i],
+            "id": dataset_dict["id"][i],
+            "needs_retrieval": dataset_dict["needs_retrieval"][i]
+        })
+
+    return HFDataset.from_list(data_list)
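+
+# Example of one resulting dataset row (illustrative values only; the real fields
+# come from the JSONL file loaded in main() below):
+#   {"prompt": [{"role": "system", "content": "You are a visual question answering ..."},
+#               {"role": "user", "content": "Question: What sport is being played?"}],
+#    "image": <PIL.Image.Image>, "answer": "polo", "dataset": "okvqa",
+#    "id": "okvqa_000123", "needs_retrieval": True}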
+
+
+def main():
+    # Configuration
+    MODEL_PATH = "/workspace/xiaobin/SFT_model/hf_qwen3vl_siglip_vqa_iter_0000881"
+    OUTPUT_DIR = "/workspace/xiaobin/RL/output"
+    SIGLIP_MODEL = "/workspace/siglip2-so400m-patch16-naflex"
+    RL_DATASET_PATH = "/workspace/xiaobin/RL_data/rl_train.jsonl"  # Pre-built dataset
+
+    # Training hyperparameters - Optimized for speed with more GPU memory
+    LEARNING_RATE = 5e-6  # Reduced from 1e-5 to slow down collapse
+    BATCH_SIZE = 64  # Increased from 32
+    NUM_GENERATIONS = 4  # 64/4=16 unique prompts per step
+    MAX_TURNS = 3
+    NUM_EPOCHS = 3
+    MAX_PROMPT_LENGTH = 512
+    MAX_COMPLETION_LENGTH = 256
+
+    # Data parameters
+    MAX_SAMPLES = None  # Use all 10000 samples
+    USE_WANDB = False  # Disabled wandb
+
+    # Initialize wandb
+    if USE_WANDB:
+        wandb.init(
+            project="grpo-retrieval-vqa",
+            config={
+                "model": MODEL_PATH,
+                "learning_rate": LEARNING_RATE,
+                "batch_size": BATCH_SIZE,
+                "num_generations": NUM_GENERATIONS,
+                "max_turns": MAX_TURNS,
+            }
+        )
+
+    # Load models
+    model, processor, siglip_model, siglip_processor = load_models(
+        MODEL_PATH,
+        SIGLIP_MODEL
+    )
+
+    # Ensure model has generation_config
+    from transformers import GenerationConfig
+    if not hasattr(model, 'generation_config') or model.generation_config is None:
+        model.generation_config = GenerationConfig.from_model_config(model.config)
+        print("Created default generation_config for model")
+
+    # Load and prepare data
+    print("\n" + "="*80)
+    print("LOADING DATA")
+    print("="*80)
+
+    import json
+    all_samples = []
+    with open(RL_DATASET_PATH, 'r', encoding='utf-8') as f:
+        for line in f:
+            sample = json.loads(line)
+            # Convert to expected format
+            all_samples.append({
+                'id': sample['id'],
+                'image': sample['image'],
+                'question': sample['question'],
+                'answer': sample['answer'],
+                'dataset': sample['subdir'],
+                'needs_retrieval': sample['category'] == 'positive'
+            })
+
+    # Limit samples if specified
+    if MAX_SAMPLES is not None and len(all_samples) > MAX_SAMPLES:
+        import random
+        random.seed(42)
+        all_samples = random.sample(all_samples, MAX_SAMPLES)
+
+    print(f"Loaded {len(all_samples)} samples from {RL_DATASET_PATH}")
+
+    # Build candidate pools for retrieval
+    positive_samples = [s for s in all_samples if s['needs_retrieval']]
+    negative_samples = [s for s in all_samples if not s['needs_retrieval']]
+    candidate_pools = build_candidate_pools(positive_samples, negative_samples)
+
+    # Convert to HF Dataset
+    train_dataset = prepare_dataset_for_grpo(all_samples)
+
+    # Initialize retrieval environment
+    print("\n" + "="*80)
+    print("INITIALIZING ENVIRONMENT")
+    print("="*80)
+    environment = RetrievalEnvironment(
+        siglip_model=siglip_model,
+        siglip_processor=siglip_processor,
+        candidate_pools=candidate_pools,
+        max_turns=MAX_TURNS,
+        device="cuda"
+    )
+
+    # Configure GRPO training
+    print("\n" + "="*80)
+    print("CONFIGURING GRPO TRAINER")
+    print("="*80)
+    training_args = GRPOConfig(
+        output_dir=OUTPUT_DIR,
+        learning_rate=LEARNING_RATE,
+        per_device_train_batch_size=BATCH_SIZE,
+        num_train_epochs=NUM_EPOCHS,
+        num_generations=NUM_GENERATIONS,
+        generation_batch_size=NUM_GENERATIONS * 16,  # 4*16=64
+        max_completion_length=MAX_COMPLETION_LENGTH,
+        logging_steps=10,
+        save_steps=100,
+        save_total_limit=3,
+        bf16=True,
+        gradient_accumulation_steps=1,
+        warmup_steps=50,
+        report_to="wandb" if USE_WANDB else "none",
+        gradient_checkpointing=True,
+        temperature=1.4,  # Higher temp for exploration
+        top_p=0.95,  # Wider sampling
+        max_grad_norm=1.0,  # Explicit gradient clipping
+        lr_scheduler_type="constant_with_warmup",  # Prevent lr decay stall
+        repetition_penalty=1.1,  # Discourage identical generations
+        beta=0.04,  # KL regularization to prevent policy collapse
+        epsilon=0.3,  # Wider clip range (default 0.2)
+    )
+
+    # Create custom reward function
+    reward_func = create_reward_function(environment)
+
+    # Initialize trainer with custom reward function
+    trainer = GRPOTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        processing_class=processor,
+        reward_funcs=reward_func,
+    )
+
+    # Start training
+    print("\n" + "="*80)
+    print("STARTING TRAINING")
+    print("="*80)
+    trainer.train()
+
+    # Save final model
+    print("\n" + "="*80)
+    print("SAVING MODEL")
+    print("="*80)
+    final_output_dir = os.path.join(OUTPUT_DIR, "final_model")
+    trainer.save_model(final_output_dir)
+    print(f"Model saved to {final_output_dir}")
+
+    if USE_WANDB:
+        wandb.finish()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ICL/RL/train_grpo_20260224_133510.log b/ICL/RL/train_grpo_20260224_133510.log
new file mode 100644
index 0000000000000000000000000000000000000000..314dc3998c404fc5d31ea43712b6fb9f4307b533
--- /dev/null
+++ b/ICL/RL/train_grpo_20260224_133510.log
@@ -0,0 +1,327 @@
+Loading VLM from /workspace/xiaobin/SFT_model/hf_qwen3vl_siglip_vqa_iter_0000881...
+ Loading weights: 0%| | 0/750 [00:00
diff --git a/ICL/RL/trl_source/CITATION.cff b/ICL/RL/trl_source/CITATION.cff
new file mode 100644
--- /dev/null
+++ b/ICL/RL/trl_source/CITATION.cff
+cff-version: 1.2.0
+title: 'TRL: Transformer Reinforcement Learning'
+message: >-
+  If you use this software, please cite it using the
+  metadata from this file.
+type: software +authors: + - given-names: Leandro + family-names: von Werra + - given-names: Younes + family-names: Belkada + - given-names: Lewis + family-names: Tunstall + - given-names: Edward + family-names: Beeching + - given-names: Tristan + family-names: Thrush + - given-names: Nathan + family-names: Lambert + - given-names: Shengyi + family-names: Huang + - given-names: Kashif + family-names: Rasul + - given-names: Quentin + family-names: Gallouédec +repository-code: 'https://github.com/huggingface/trl' +abstract: >- + TRL (Transformers Reinforcement Learning) is an + open-source toolkit for aligning transformer models via + post-training. It provides practical, scalable + implementations of SFT, reward modeling, DPO, and GRPO + within the Hugging Face ecosystem. +keywords: + - transformers + - reinforcement learning + - preference optimization + - language model alignment + - post-training +license: Apache-2.0 +version: '0.28' +date-released: '2020-03-27' diff --git a/ICL/RL/trl_source/CODE_OF_CONDUCT.md b/ICL/RL/trl_source/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..ef09fa1375a81440bf0733b659045453a5476c43 --- /dev/null +++ b/ICL/RL/trl_source/CODE_OF_CONDUCT.md @@ -0,0 +1,133 @@ + +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. 
+ +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +feedback@huggingface.co. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.1, available at +[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at +[https://www.contributor-covenant.org/translations][translations]. 
+ +[homepage]: https://www.contributor-covenant.org +[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations \ No newline at end of file diff --git a/ICL/RL/trl_source/CONTRIBUTING.md b/ICL/RL/trl_source/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..7229a2b8592001cfba6c6d6021930033a503b8db --- /dev/null +++ b/ICL/RL/trl_source/CONTRIBUTING.md @@ -0,0 +1,411 @@ +# How to contribute to TRL? + +Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable. + +It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you. + +However you choose to contribute, please be mindful and respect our [code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md). + +**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).** + +## Ways to contribute + +There are several ways you can contribute to TRL: + +* Fix outstanding issues with the existing code. +* Submit issues related to bugs or desired new features. +* Implement trainers for new post-training algorithms. +* Contribute to the examples or the documentation. + +If you don't know where to start, there is a special [Good First Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over. + +For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/%F0%9F%A7%92%20good%20second%20issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀 + +> All contributions are equally valuable to the community. 🥰 + +Before you start contributing make sure you have installed all the dev tools: + +```bash +pip install -e .[dev] +``` + +## Fixing outstanding issues + +If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request! + +## Submitting a bug-related issue or feature request + +Do your best to follow these guidelines when submitting a bug-related issue or a feature request. It will make it easier for us to come back to you quickly and with good feedback. + +### Did you find a bug? + +The TRL library is robust and reliable thanks to users who report the problems they encounter. + +Before you report an issue, we would really appreciate it if you could **make sure the bug was not already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code. 
+ +Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it: + +* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions. +* A short, self-contained, code snippet that allows us to reproduce the bug in less than 30s. +* The *full* traceback if an exception is raised. +* Attach any other additional information, like screenshots, you think may help. + +To get the OS and software versions automatically, run the following command: + +```bash +trl env +``` + +### Do you want a new feature? + +If there is a new feature you'd like to see in TRL, please open an issue and describe: + +1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? Is it something you worked on and think it could benefit the community? + + Whatever it is, we'd love to hear about it! + +2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you. +3. Provide a *code snippet* that demonstrates the feature's usage. +4. If the feature is related to a paper, please include a link. + +If your issue is well written we're already 80% of the way there by the time you create it. + +## Do you want to implement a new trainer? + +New post-training methods are published frequently and those that satisfy the following criteria are good candidates to be integrated into TRL: + +* **Simplicity:** Does the new method achieve similar performance as prior methods, but with less complexity? A good example is Direct Preference Optimization (DPO) [[Rafailov et al, 2023]](https://huggingface.co/papers/2305.18290), which provided a simpler and compelling alternative to RLHF methods. +* **Efficiency:** Does the new method provide a significant improvement in training efficiency? A good example is Odds Ratio Preference Optimization (ORPO) [[Hong et al, 2023]](https://huggingface.co/papers/2403.07691), which utilizes a similar objective as DPO but requires half the GPU VRAM. + +Methods that only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL. + +If you want to implement a trainer for a new post-training method, first open an issue and provide the following information: + +* A short description of the method and a link to the paper. +* Link to the implementation if it is open-sourced. +* Link to model weights trained with the method if they are available. + +Based on the community and maintainer feedback, the next step will be to implement the trainer and config classes. See the following examples for inspiration: + +* Paired preference optimisation: [`dpo_trainer.py`](./trl/trainer/dpo_trainer.py) and [`dpo_config.py`](./trl/trainer/dpo_config.py) +* RL-based optimisation: [`rloo_trainer.py`](./trl/trainer/rloo_trainer.py) and [`rloo_config.py`](./trl/trainer/rloo_config.py) +* Online optimisation: [`online_dpo_trainer.py`](./trl/trainer/online_dpo_trainer.py) and [`online_dpo_config.py`](./trl/trainer/online_dpo_config.py) + +## Do you want to add documentation? + +We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know how the documentation can be improved, such as typos, dead links, and any missing, unclear, or inaccurate content... 
+We'll be happy to make the changes or help you contribute if you're interested!
+
+## Submitting a pull request (PR)
+
+Before writing code, we strongly advise you to search through the existing PRs or issues to make sure that nobody is already working on the same thing. If you are unsure, it is always a good idea to open an issue to get some feedback.
+
+You will need basic `git` proficiency to be able to contribute to TRL. `git` is not the easiest tool to use but it has the greatest manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro Git](https://git-scm.com/book/en/v2) is a very good reference.
+
+Follow these steps to start contributing:
+
+1. Fork the [repository](https://github.com/huggingface/trl) by clicking on the 'Fork' button on the repository's page. This creates a copy of the code under your GitHub user account.
+
+2. Clone your fork to your local disk, and add the base repository as a remote. The following command assumes you have your public SSH key uploaded to GitHub. See [this guide](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository) for more information.
+
+   ```bash
+   git clone git@github.com:<your GitHub handle>/trl.git
+   cd trl
+   git remote add upstream https://github.com/huggingface/trl.git
+   ```
+
+3. Create a new branch to hold your development changes, and do this for every new PR you work on.
+
+   Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
+
+   ```bash
+   git checkout main
+   git fetch upstream
+   git merge upstream/main
+   ```
+
+   Once your `main` branch is synchronized, create a new branch from it:
+
+   ```bash
+   git checkout -b a-descriptive-name-for-my-changes
+   ```
+
+   **Do not** work on the `main` branch.
+
+4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
+
+   ```bash
+   pip install -e .[dev]
+   ```
+
+   (If TRL was already installed in the virtual environment, remove it with `pip uninstall trl` before reinstalling it.)
+
+   Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using the provided Dev Container. Check [the documentation on how to get started with dev containers](https://code.visualstudio.com/docs/remote/containers).
+
+5. Develop the features on your branch.
+
+   As you work on the features, you should make sure that the test suite passes. Run the tests impacted by your changes like this:
+
+   ```bash
+   pytest tests/<TEST_TO_RUN>.py
+   ```
+
+   > The following commands leverage the `make` utility.
+
+   You can also run the full suite with the following command:
+
+   ```bash
+   make test
+   ```
+
+   TRL relies on `ruff` for maintaining consistent code formatting across its source files. Before submitting any PR, you should apply automatic style corrections and run code verification checks.
+
+   We provide a `precommit` target in the `Makefile` that simplifies this process by running all required checks and optimizations on only the files modified by your PR.
+
+   To apply these checks and corrections in one step, use:
+
+   ```bash
+   make precommit
+   ```
+
+   This command runs the following:
+
+   * Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
+ * Runs additional scripts such as adding copyright information. + + If you prefer to apply the style corrections separately or review them individually, the `pre-commit` hook will handle the formatting for the files in question. + + Once you're happy with your changes, add changed files using `git add` and make a commit with `git commit` to record your changes locally: + + ```bash + git add modified_file.py + git commit + ``` + + Please write [good commit messages](https://chris.beams.io/posts/git-commit/). + + It is a good idea to sync your copy of the code with the original + repository regularly. This way you can quickly account for changes: + + ```bash + git fetch upstream + git rebase upstream/main + ``` + + Push the changes to your account using: + + ```bash + git push -u origin a-descriptive-name-for-my-changes + ``` + +6. Once you are satisfied (**and the checklist below is happy too**), go to the webpage of your fork on GitHub. Click on 'Pull request' to send your changes to the project maintainers for review. + +7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request. + +### Checklist + +1. The title of your pull request should be a summary of its contribution; +2. If your pull request addresses an issue, please mention the issue number in the pull request description to make sure they are linked (and people consulting the issue know you are working on it); +3. To indicate a work in progress please prefix the title with `[WIP]`, or mark the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate it from PRs ready to be merged; +4. Make sure existing tests pass; +5. Add high-coverage tests. No quality testing = no merge. + +### Tests + +An extensive test suite is included to test the library behavior and several examples. Library tests can be found in +the [tests folder](https://github.com/huggingface/trl/tree/main/tests). + +We use `pytest` to run the tests. From the root of the +repository here's how to run tests with `pytest` for the library: + +```bash +python -m pytest -sv ./tests +``` + +That's how `make test` is implemented (without the `pip install` line)! + +You can specify a smaller set of tests to test only the feature +you're working on. + +### Default values guidelines + +1. **Use defaults when appropriate**: + + Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should. + +2. **Prioritize proven defaults**: + + Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases. + +3. **Ensure safety and predictability**: + + Defaults must be safe, expected and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases. + +4. **Balance consistency and flexibility**: + + Aim for consistent defaults across similar functions or methods. However, consistency should not be preferred to point 2 or 3. + +5. **Opt-in for new features**: + + Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt-in to use these. 
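+
+To make these guidelines concrete, here is a minimal sketch of a config class that follows them. `MyTrainerConfig` and its fields are hypothetical, for illustration only, and not part of TRL:
+
+```python
+from dataclasses import dataclass
+
+
+@dataclass
+class MyTrainerConfig:
+    # No default: the model to train varies too much by use case (guideline 1).
+    model_name_or_path: str
+    # Proven default: matches the value recommended in the original paper (guideline 2).
+    learning_rate: float = 1e-5
+    # Safe and predictable: a conservative cap avoids surprising memory usage (guideline 3).
+    max_length: int | None = 1024
+    # Opt-in: a novel loss variant stays disabled unless explicitly enabled (guideline 5).
+    use_novel_loss: bool = False
+```
+
+Guideline 4 (consistency) only shows up across several such classes, so it is not illustrated here.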
+ +### Writing documentation + +High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project. + +To illustrate what good documentation looks like, here’s an example of a well-documented function: + +````python +def replicate_str(string: str, n: int, sep: str = " ") -> str: + r""" + Replicate a string `n` times with a separator. + + Args: + string (`str`): + String to replicate. + n (`int`): + Number of times to replicate the string. + sep (`str`, *optional*, defaults to `" "`): + Separator to use between each replication. + + Returns: + `str`: The replicated string. + + Examples: + ```python + >>> replicate_str("hello", 3) + "hello hello hello" + >>> replicate_str("hello", 3, sep=", ") + "hello, hello, hello" + ``` + """ + return sep.join([string] * n) +```` + +* **Line Wrapping:** Applied a consistent line wrap at column 120 to improve readability. +* **Definite Articles:** Removed definite articles where possible to streamline language. (Eg: Changed "The string to replicate" to "String to replicate") +* **Type Annotations:** + * Always include type definitions, indicating if a parameter is optional and specifying the default value. + +* **String Defaults:** + * Ensured that default string values are wrapped in double quotes: + + ```txt + defaults to `"foo"` + ``` + +* **Dictionary Typing:** + * Replaced generic `dict` type hints with more explicit `dict[str, Any]` to clarify expected key-value pairs. +* **Default Value Formatting:** + * Consistently surrounded default values with backticks for improved formatting: + + ```txt + defaults to `4` + ``` + +* **Sub-sectioning:** When the number of arguments is large, consider breaking them into sub-sections for better readability. + + ```python + def calculate_statistics(data: list[float], precision: int = 2, include_variance: bool = False) -> dict[str, float]: + r""" + Calculates basic statistics for a given dataset. + + Args: + > Data inputs + + data (`list[float]`): + A list of numerical values to analyze. + + > Configuration parameters + + precision (`int`, *optional*, defaults to `2`): + Number of decimal places to round the results. + include_variance (`bool`, *optional*, defaults to `False`): + Whether to include the variance of the dataset in the results. + + Returns: + `dict[str, float]`: + A dictionary containing calculated statistics such as mean, median, and optionally variance. + """ + ... + ``` + +### Deprecation and backward compatibility + +Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs. + +When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include: + +* **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement. +* **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition. + +Example: + + ```python + warnings.warn( + "The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. 
" + "Please use the `Trainer.bar` class instead.", + FutureWarning, + stacklevel=2, + ) + ``` + +The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes: + +* **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next. + +* **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning. + +These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs. + +### Working with warnings + +Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely. + +#### Definitions + +* **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation. +* **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future. + +#### Choosing the right message + +* **Correct → No warning**: + If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary. + +* **Correct but deserves attention → No warning, possibly a log message**: + When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example: + + ```python + logger.info("This is an informational message about a rare but correct operation.") + ``` + +* **Correct but very likely a mistake → Warning with option to disable**: + In rare cases, you may want to issue a warning for a correct operation that’s very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example: + + ```python + def my_function(foo, bar, _warn=True): + if foo == bar: + if _warn: + logger.warning("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.") + # Do something + ``` + +* **Supported but not correct → Warning**: + If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example: + + ```python + def my_function(foo, bar): + if foo and bar: + logger.warning("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. 
Please pass only one of these arguments.") + # Do something + ``` + +* **Not supported → Exception**: + If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example: + + ```python + def my_function(foo, bar): + if foo and bar: + raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.") + ``` + +By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages. diff --git a/ICL/RL/trl_source/LICENSE b/ICL/RL/trl_source/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f577b7741dfb5c6af119f250971549dd7750acb9 --- /dev/null +++ b/ICL/RL/trl_source/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020-2026 The HuggingFace Team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/ICL/RL/trl_source/MANIFEST.in b/ICL/RL/trl_source/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..595e7a1505383b3d9753d92b8e415c69dda5c550 --- /dev/null +++ b/ICL/RL/trl_source/MANIFEST.in @@ -0,0 +1,7 @@ +include LICENSE +include CONTRIBUTING.md +include README.md +include trl/accelerate_configs/*.yaml +include trl/templates/*.md +recursive-exclude * __pycache__ +prune tests diff --git a/ICL/RL/trl_source/Makefile b/ICL/RL/trl_source/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..c84591a89553c9c7c3fcd04af761afee9159d5a9 --- /dev/null +++ b/ICL/RL/trl_source/Makefile @@ -0,0 +1,19 @@ +.PHONY: test precommit common_tests slow_tests tests_gpu test_experimental + +check_dirs := examples tests trl + +ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs + +test: + pytest -n auto -m "not slow and not low_priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests + +precommit: + python scripts/add_copyrights.py + pre-commit run --all-files + doc-builder style trl tests docs/source --max_len 119 + +slow_tests: + pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",) + +test_experimental: + pytest -n auto -s -v tests/experimental diff --git a/ICL/RL/trl_source/README.md b/ICL/RL/trl_source/README.md new file mode 100644 index 0000000000000000000000000000000000000000..98cfe0e30b442361d738b12cba1a998955e7ce3b --- /dev/null +++ b/ICL/RL/trl_source/README.md @@ -0,0 +1,207 @@ +# TRL - Transformer Reinforcement Learning + +
+<div align="center">
+  TRL Banner
+</div>
+
+<h3 align="center">
+  A comprehensive library to post-train foundation models
+</h3>
+
+<div align="center">
+  License · Documentation · GitHub release · Hugging Face Hub
+</div>
+ +## 🎉 What's New + +**OpenEnv Integration:** TRL now supports **[OpenEnv](https://huggingface.co/blog/openenv)**, the open-source framework from Meta for defining, deploying, and interacting with environments in reinforcement learning and agentic workflows. + +Explore how to seamlessly integrate TRL with OpenEnv in our [dedicated documentation](https://huggingface.co/docs/trl/openenv). + +## Overview + +TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups. + +## Highlights + +- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer) and more. + +- **Efficient and scalable**: + - Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like [DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [DeepSpeed](https://github.com/deepspeedai/DeepSpeed). + - Full integration with [🤗 PEFT](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA. + - Integrates [🦥 Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels. + +- **Command Line Interface (CLI)**: A simple interface lets you fine-tune with models without needing to write code. + +## Installation + +### Python Package + +Install the library using `pip`: + +```bash +pip install trl +``` + +### From source + +If you want to use the latest features before an official release, you can install TRL from source: + +```bash +pip install git+https://github.com/huggingface/trl.git +``` + +### Repository + +If you want to use the examples you can clone the repository with the following command: + +```bash +git clone https://github.com/huggingface/trl.git +``` + +## Quick Start + +For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP. + +### `SFTTrainer` + +Here is a basic example of how to use the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer): + +```python +from trl import SFTTrainer +from datasets import load_dataset + +dataset = load_dataset("trl-lib/Capybara", split="train") + +trainer = SFTTrainer( + model="Qwen/Qwen2.5-0.5B", + train_dataset=dataset, +) +trainer.train() +``` + +### `GRPOTrainer` + +[`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer) implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1). 
+ +```python +from datasets import load_dataset +from trl import GRPOTrainer +from trl.rewards import accuracy_reward + +dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + +trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + reward_funcs=accuracy_reward, + train_dataset=dataset, +) +trainer.train() +``` + +> [!NOTE] +> For reasoning models, use the `reasoning_accuracy_reward()` function for better results. + +### `DPOTrainer` + +[`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer) implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train [Llama 3](https://huggingface.co/papers/2407.21783) and many other models. Here is a basic example of how to use the `DPOTrainer`: + +```python +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import DPOConfig, DPOTrainer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") +trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer +) +trainer.train() +``` + +### `RewardTrainer` + +Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): + +```python +from trl import RewardTrainer +from datasets import load_dataset + +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +trainer = RewardTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + train_dataset=dataset, +) +trainer.train() +``` + +## Command Line Interface (CLI) + +You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO): + +**SFT:** + +```bash +trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/Capybara \ + --output_dir Qwen2.5-0.5B-SFT +``` + +**DPO:** + +```bash +trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --dataset_name argilla/Capybara-Preferences \ + --output_dir Qwen2.5-0.5B-DPO +``` + +Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/clis) or use `--help` for more details. + +## Development + +If you want to contribute to `trl` or customize it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install: + +```bash +git clone https://github.com/huggingface/trl.git +cd trl/ +pip install -e .[dev] +``` + +## Experimental + +A minimal incubation area is available under `trl.experimental` for unstable / fast-evolving features. Anything there may change or be removed in any release without notice. + +Example: + +```python +from trl.experimental.new_trainer import NewTrainer +``` + +Read more in the [Experimental docs](https://huggingface.co/docs/trl/experimental_overview). 
+ +## Citation + +```bibtex +@software{vonwerra2020trl, + title = {{TRL: Transformers Reinforcement Learning}}, + author = {von Werra, Leandro and Belkada, Younes and Tunstall, Lewis and Beeching, Edward and Thrush, Tristan and Lambert, Nathan and Huang, Shengyi and Rasul, Kashif and Gallouédec, Quentin}, + license = {Apache-2.0}, + url = {https://github.com/huggingface/trl}, + year = {2020} +} +``` + +## License + +This repository's source code is available under the [Apache-2.0 License](LICENSE). diff --git a/ICL/RL/trl_source/RELEASE.md b/ICL/RL/trl_source/RELEASE.md new file mode 100644 index 0000000000000000000000000000000000000000..67087ad7ac8b1d2f7a5ea96a705944eda723fc03 --- /dev/null +++ b/ICL/RL/trl_source/RELEASE.md @@ -0,0 +1,167 @@ +# Making a release + +> [!NOTE] +> VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We need to follow this convention to be able to retrieve versioned scripts. + +## Major/Minor Release + +### 1. Ensure your local repository is up to date with the upstream repository + +```bash +git checkout main +git pull origin main +``` + +> [!WARNING] +> Do not merge other pull requests into `main` until the release is done. This is to ensure that the release is stable and does not include any untested changes. Announce internally (#trl-internal) to other maintainers that you are doing a release and that they must not merge PRs until the release is done. + +### 2. Create a release branch from main + +```bash +git checkout -b release-v{major}.{minor} +``` + +### 3. Change the version in the following files + +- `.github/workflows/tests_latest.yml`: + + ```diff + - with: { ref: v{major}.{minor-1}-release } + + with: { ref: v{major}.{minor}-release } + ``` + +- `CITATION.cff` + + ```diff + - version: '{major}.{minor-1}' + + version: '{major}.{minor}' + ``` + +- `VERSION` + + ```diff + - {major}.{minor}.0.dev0 + + {major}.{minor}.0 + ``` + +### 4. Commit and push these changes + +```shell +git add .github/workflows/tests_latest.yml CITATION.cff VERSION +git commit -m 'Release: {major}.{minor}' +git push origin release-v{major}.{minor} +``` + +### 5. Create a pull request + +from `release-v{major}.{minor}` to `main`, named `Release: v{major}.{minor}`, wait for tests to pass, and request a review. + +### 6. Once the pull request is approved, merge it into `main` + +It will automatically publish the new version of the package on PyPI. + +### 7. Add a tag in git to mark the release + +```shell +git checkout main +git pull origin main +git tag -a v{major}.{minor}.0 -m 'Adds tag v{major}.{minor}.0 for PyPI' +git push origin v{major}.{minor}.0 +``` + +### 8. Create a branch `v{major}.{minor}-release` for future patch releases + +```shell +git checkout -b v{major}.{minor}-release +git push origin v{major}.{minor}-release +``` + +This ensures that future patch releases (`v{major}.{minor}.1`, `v{major}.{minor}.2`, etc.) can be made separately from `main`. + +### 9. Create a GitHub Release + +1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub. +2. Click **Draft a new release**. +3. Select the `v{major}.{minor}.0` tag you just created in step 7. +4. Add a title (`v{major}.{minor}.0`) and a short description of what’s new. +5. Click **Publish Release**. + +### 10. Bump to dev version + +1. Create a branch `bump-dev-version-{major}.{minor+1}` from `main` and checkout to it. + + ```shell + git checkout -b bump-dev-version-{major}.{minor+1} + ``` + +2. 
Change the version in the file `VERSION`:
+
+   ```diff
+   - {major}.{minor}.0
+   + {major}.{minor+1}.0.dev0
+   ```
+
+3. Commit and push these changes
+
+   ```shell
+   git add VERSION
+   git commit -m '⬆️ Bump dev version'
+   git push origin bump-dev-version-{major}.{minor+1}
+   ```
+
+4. Create a pull request from `bump-dev-version-{major}.{minor+1}` to `main`, named `⬆️ Bump dev version`, and request an urgent review.
+
+5. Once the pull request is approved, merge it into `main`.
+
+6. The codebase is now ready for the next development cycle; inform the team in the #trl-internal channel.
+
+## Making a patch release
+
+### 1. Ensure your local repository is up to date with the upstream repository
+
+```bash
+git checkout v{major}.{minor}-release
+git pull origin main
+```
+
+### 2. Cherry-pick the changes you want to include in the patch release
+
+```bash
+git cherry-pick <commit-hash>
+git cherry-pick <commit-hash>
+...
+```
+
+### 3. Change the version in the file `VERSION`
+
+```diff
+- {major}.{minor}.{patch-1}
++ {major}.{minor}.{patch}
+```
+
+### 4. Commit and push these changes
+
+```shell
+git add VERSION
+git commit -m 'Release: {major}.{minor}.{patch}'
+git push origin v{major}.{minor}-release
+```
+
+### 5. Wait for the CI to pass
+
+The CI will automatically publish the new version of the package on PyPI.
+
+### 6. Add a tag in git to mark the release
+
+```shell
+git tag -a v{major}.{minor}.{patch} -m 'Adds tag v{major}.{minor}.{patch} for PyPI'
+git push origin v{major}.{minor}.{patch}
+```
+
+### 7. Create a GitHub Release
+
+1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub.
+2. Click **Draft a new release**.
+3. Select the `v{major}.{minor}.{patch}` tag you just created in step 6.
+4. Add a title (`v{major}.{minor}.{patch}`) and a short description of what’s new.
+5. Click **Publish Release**.
diff --git a/ICL/RL/trl_source/VERSION b/ICL/RL/trl_source/VERSION
new file mode 100644
index 0000000000000000000000000000000000000000..834eaf6bf8d02da86d4fa3c0b724ad4a52e7f75e
--- /dev/null
+++ b/ICL/RL/trl_source/VERSION
@@ -0,0 +1 @@
+0.29.0.dev0
\ No newline at end of file
diff --git a/ICL/RL/trl_source/pyproject.toml b/ICL/RL/trl_source/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..7694c59eb0d65b28e51046dcece99ce6dd34080e
--- /dev/null
+++ b/ICL/RL/trl_source/pyproject.toml
@@ -0,0 +1,194 @@
+[build-system]
+requires = ["setuptools >= 77.0.3"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "trl"
+description = "Train transformer language models with reinforcement learning."
+authors = [ + { name = "Leandro von Werra", email = "leandro.vonwerra@gmail.com" } +] +readme = { file = "README.md", content-type = "text/markdown" } +license = "Apache-2.0" +license-files = ["LICENSE"] +keywords = [ + "transformers", "huggingface", "language modeling", "post-training", "rlhf", "sft", "dpo", "grpo" +] +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13" +] +requires-python = ">=3.10" +dependencies = [ + "accelerate>=1.4.0", + "datasets>=3.0.0", + "packaging>20.0", + "transformers>=4.56.2", +] +dynamic = ["version"] + +[project.urls] +Homepage = "https://github.com/huggingface/trl" + +[project.scripts] +trl = "trl.cli:main" + +[project.optional-dependencies] +bco = [ + "scikit-learn", + "joblib" +] +deepspeed = [ + "deepspeed>=0.14.4", + "transformers!=5.1.0", # see transformers#43780 +] +judges = [ + "openai>=1.23.2", + "llm-blender>=0.0.2", + "transformers<5.0.0", # see #4918 +] +kernels = [ + "kernels" +] +liger = [ + "liger-kernel>=0.6.4" +] +peft = [ + "peft>=0.8.0" +] +quality = [ + "pre-commit", + "hf-doc-builder" +] +quantization = [ + "bitsandbytes" +] +scikit = [ + "scikit-learn" +] +test = [ + "pytest-cov", + "pytest-datadir>=1.7.0", # lazy datadirs + "pytest-rerunfailures==15.1", + "pytest-xdist", + "pytest" +] +vllm = [ + "vllm>=0.10.2,<0.13.0", + "fastapi", + "pydantic", + "requests", + "uvicorn" +] +vlm = [ + "Pillow", + "torchvision", + "num2words==0.5.14" +] +math_verify = [ + "math-verify>=0.5.2", +] +dev = [ + # bco + "scikit-learn", + "joblib", + # deepspeed + "deepspeed>=0.14.4", + # judges + "openai>=1.23.2", + "llm-blender>=0.0.2", + # kernels + "kernels", + # liger + "liger-kernel>=0.6.4", + # peft + "peft>=0.8.0", + # quality + "pre-commit", + "hf-doc-builder", + # quantization + "bitsandbytes", + # scikit: included in bco + # test + "pytest-cov", + "pytest-datadir>=1.7.0", # lazy datadirs + "pytest-rerunfailures==15.1", + "pytest-xdist", + "pytest", + # vllm: not included in dev by default due to CUDA error; see GH-4228 + # vlm + "Pillow", + "torchvision", + "num2words==0.5.14", + # for response parsing (required for training with tools) + "jmespath", +] + +[tool.setuptools] +package-dir = {"trl" = "trl"} + +[tool.setuptools.dynamic] +version = { file = "VERSION" } + +[tool.coverage.run] +branch = true + +[tool.ruff] +target-version = "py310" +line-length = 119 +src = ["trl"] + +[tool.ruff.lint] +ignore = [ + "B028", # warning without explicit stacklevel + "C408", # dict() calls (stylistic) + "C901", # function complexity + "E501", +] +extend-select = ["E", "F", "I", "W", "UP", "B", "T", "C"] + +[tool.ruff.lint.per-file-ignores] +# Allow prints in auxiliary scripts +"examples/**.py" = ["T201"] +"scripts/**.py" = ["T201"] +# Ignore import violations in all `__init__.py` files. 
+"__init__.py" = ["F401"] + +[tool.ruff.lint.isort] +lines-after-imports = 2 +known-first-party = ["trl"] + +[tool.pytest.ini_options] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "low_priority: marks tests as low priority (deselect with '-m \"not low_priority\"')" +] +norecursedirs = [ + "tests/experimental", +] +filterwarnings = [ + # SWIG deprecations from SWIG-generated C/C++ extensions: sentencepiece + # Upstream issue: https://github.com/google/sentencepiece/issues/1150 + "ignore:builtin type SwigPyPacked has no __module__ attribute:DeprecationWarning", + "ignore:builtin type SwigPyObject has no __module__ attribute:DeprecationWarning", + "ignore:builtin type swigvarlink has no __module__ attribute:DeprecationWarning", + + # PyTorch JIT deprecations (upstream, not actionable in TRL) + # Upstream issue: https://github.com/deepspeedai/DeepSpeed/issues/7835 + "ignore:`torch.jit.script_method` is deprecated:DeprecationWarning", + "ignore:`torch.jit.script` is deprecated:DeprecationWarning", + + # PyTorch DataLoader pin_memory device argument deprecations + # Triggered internally by torch.utils.data, not by our code + # Upstream issue: https://github.com/pytorch/pytorch/issues/174546 + "ignore:The argument 'device' of Tensor.pin_memory:DeprecationWarning", + "ignore:The argument 'device' of Tensor.is_pinned:DeprecationWarning", +] diff --git a/ICL/RL/trl_source/requirements.txt b/ICL/RL/trl_source/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b78fda1b76e6a12d13b05a307b5adb6ae07d7084 --- /dev/null +++ b/ICL/RL/trl_source/requirements.txt @@ -0,0 +1,3 @@ +accelerate>=1.4.0 +datasets>=3.0.0 +transformers>=4.56.2 \ No newline at end of file diff --git a/ICL/RL/trl_source/tests/__init__.py b/ICL/RL/trl_source/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2777dd0eb21a0ea67dce8775337a27ea8499da2 --- /dev/null +++ b/ICL/RL/trl_source/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ICL/RL/trl_source/tests/conftest.py b/ICL/RL/trl_source/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..f071b789ffd082ed66087339c67ac7e6fab50700 --- /dev/null +++ b/ICL/RL/trl_source/tests/conftest.py @@ -0,0 +1,89 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +from functools import wraps + +import pytest +import torch + + +# ============================================================================ +# Model Revision Override +# ============================================================================ +# To test a tiny model PR before merging to main: +# 1. Add the full model_id and PR revision to this dict +# 2. Commit and push to trigger CI +# 3. Once CI is green, merge the tiny model PR on HF Hub +# 4. Remove the entry from this dict and commit +# +# Example: +# MODEL_REVISIONS = { +# "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5": "refs/pr/3", +# "trl-internal-testing/tiny-LlavaForConditionalGeneration": "refs/pr/5", +# } +# ============================================================================ + +MODEL_REVISIONS = { + # Add model_id: revision mappings here to test PRs +} + + +@pytest.fixture(autouse=True) +def apply_model_revisions(monkeypatch): + """Auto-inject revision parameter for models defined in MODEL_REVISIONS.""" + if not MODEL_REVISIONS: + return + + from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin + + def create_classmethod_wrapper(original_classmethod): + # Extract the underlying function from the classmethod + original_func = original_classmethod.__func__ + + @wraps(original_func) + def wrapper(cls, pretrained_model_name_or_path, *args, **kwargs): + # Direct lookup: only inject if model_id is in the override dict + if pretrained_model_name_or_path in MODEL_REVISIONS: + if "revision" not in kwargs: + kwargs["revision"] = MODEL_REVISIONS[pretrained_model_name_or_path] + + return original_func(cls, pretrained_model_name_or_path, *args, **kwargs) + + # Re-wrap as classmethod + return classmethod(wrapper) + + # Patch all transformers Auto* classes + for cls in [ + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + ]: + monkeypatch.setattr(cls, "from_pretrained", create_classmethod_wrapper(cls.from_pretrained)) + + +@pytest.fixture(autouse=True) +def cleanup_gpu(): + """ + Automatically cleanup GPU memory after each test. + + This fixture helps prevent CUDA out of memory errors when running tests in parallel with pytest-xdist by ensuring + models and tensors are properly garbage collected and GPU memory caches are cleared between tests. + """ + yield + # Cleanup after test + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() diff --git a/ICL/RL/trl_source/tests/test_activation_offloading.py b/ICL/RL/trl_source/tests/test_activation_offloading.py new file mode 100644 index 0000000000000000000000000000000000000000..35d8432d81d13f28e4bc1072dd30ae5228770d93 --- /dev/null +++ b/ICL/RL/trl_source/tests/test_activation_offloading.py @@ -0,0 +1,209 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import torch
+from torch import nn
+from transformers import AutoModelForCausalLM
+from transformers.testing_utils import torch_device
+from transformers.utils import is_peft_available
+
+from trl.models.activation_offloading import NoOpManager, OffloadActivations
+
+from .testing_utils import TrlTestCase, require_peft, require_torch_accelerator
+
+
+if is_peft_available():
+    from peft import LoraConfig, get_peft_model
+
+
+class TestActivationOffloading(TrlTestCase):
+    @require_torch_accelerator
+    @require_peft
+    def test_offloading_with_peft_models(self) -> None:
+        """Test that activation offloading works with PEFT models."""
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+        peft_config = LoraConfig(
+            lora_alpha=16,
+            lora_dropout=0.1,
+            r=8,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+
+        model = get_peft_model(model, peft_config)
+        inp = torch.randint(0, 100, (2, 10), device=torch_device)
+
+        # First forward-backward pass without offloading
+        torch.manual_seed(42)
+        loss = model(inp, labels=inp).loss
+        loss.backward()
+
+        # Store gradients - only from trainable parameters
+        grads_original = []
+        for name, param in model.named_parameters():
+            if param.requires_grad and param.grad is not None:
+                grads_original.append((name, param.grad.clone()))
+
+        # Reset gradients
+        for p in model.parameters():
+            if p.grad is not None:
+                p.grad = None
+
+        # Second forward-backward pass with offloading
+        torch.manual_seed(42)
+        with OffloadActivations():
+            loss_c = model(inp, labels=inp).loss
+            loss_c.backward()
+
+        # Compare gradients - only trainable parameters
+        for name_orig, grad_orig in grads_original:
+            for name_param, param in model.named_parameters():
+                if name_param == name_orig and param.requires_grad and param.grad is not None:
+                    torch.testing.assert_close(
+                        grad_orig, param.grad, rtol=1e-4, atol=1e-5, msg=f"Gradient mismatch for {name_orig}"
+                    )
+
+    @require_torch_accelerator
+    def test_noop_manager_with_offloading(self):
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+        inp = torch.randint(0, 100, (2, 10), device=torch_device)
+
+        # Run with offloading but disable for specific section
+        with OffloadActivations():
+            # First forward-backward with normal offloading
+            torch.manual_seed(42)
+            out1 = model(inp, labels=inp)
+            out1.loss.backward()
+            grads1 = [p.grad.clone() for p in model.parameters()]
+
+            # Reset grads
+            for p in model.parameters():
+                p.grad = None
+
+            # Second forward-backward with NoOpManager
+            with NoOpManager():
+                torch.manual_seed(42)
+                out2 = model(inp, labels=inp)
+                out2.loss.backward()
+
+            grads2 = [p.grad.clone() for p in model.parameters()]
+
+        # Gradients should match as NoOpManager should have prevented offloading
+        for g1, g2 in zip(grads1, grads2, strict=True):
+            torch.testing.assert_close(g1, g2, rtol=1e-4, atol=1e-5)
+
+    @require_torch_accelerator
+    def test_min_offload_size(self):
+        """Test that tensors smaller than min_offload_size aren't offloaded"""
+        model = nn.Sequential(
+            nn.Linear(5, 5),  # Small layer that shouldn't be offloaded
+            nn.Linear(5, 1000),  # Large layer that should be offloaded
+        ).to(torch_device)
+
+        inp = torch.randn(2, 5, device=torch_device)
+
+        with OffloadActivations(min_offload_size=1000):
+            out = model(inp)
+            out.sum().backward()
+
+        # The test passes if no errors occur, as we're mainly testing
+        # that the logic handles both offloaded and non-offloaded
tensors + + @require_torch_accelerator + def test_real_hf_model(self): + """Test with an actual HuggingFace model""" + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device) + + # Create small input + inp = torch.randint(0, 100, (2, 10), device=torch_device) + + # Baseline without offloading + torch.manual_seed(42) + out1 = model(inp, labels=inp).loss + out1.backward() + grads1 = [p.grad.clone() for p in model.parameters()] + + # Reset grads + for p in model.parameters(): + p.grad = None + + # With offloading + with OffloadActivations(): + torch.manual_seed(42) + out2 = model(inp, labels=inp).loss + out2.backward() + + grads2 = [p.grad.clone() for p in model.parameters()] + + # Check outputs and gradients match + torch.testing.assert_close(out1, out2) + for g1, g2 in zip(grads1, grads2, strict=True): + torch.testing.assert_close(g1, g2) + + @require_torch_accelerator + def test_tensor_deduplication(self): + """Test that deduplication works correctly for tensors sharing storage""" + + class ModelWithViews(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(100, 100) + + def forward(self, x): + out = self.linear(x) + view1 = out.view(-1) + view2 = out.transpose(0, 1) + return view1.sum() + view2.sum() + + model = ModelWithViews().to(torch_device) + offload_ctx = OffloadActivations(min_offload_size=1) + offload_ctx.update_model_params(model) + + x = torch.randn(10, 100, device=torch_device, requires_grad=True) + with offload_ctx: + loss = model(x) + + total_tensor_ids = offload_ctx.tensor_id + assert total_tensor_ids > 0, "Should have created tensor IDs" + + # modified=True means offloaded to CPU, modified=False means kept on GPU (deduplicated) + deduplicated_count = sum(1 for _, modified, _, _, _ in offload_ctx.tracker.values() if not modified) + offloaded_count = sum(1 for _, modified, _, _, _ in offload_ctx.tracker.values() if modified) + + assert offloaded_count > 0, "Should have offloaded at least one tensor" + assert deduplicated_count > 0, "Should have deduplicated at least one tensor (view)" + + unique_storages_offloaded = len(offload_ctx.storage_to_tensor_id) + assert unique_storages_offloaded < total_tensor_ids, ( + f"Deduplication should result in fewer storages ({unique_storages_offloaded}) " + f"than total tensors ({total_tensor_ids})" + ) + + loss.backward() + + @require_torch_accelerator + def test_parameter_filtering(self): + """Test that model parameters are filtered during offloading""" + model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 10)).to(torch_device) + offload_ctx = OffloadActivations() + offload_ctx.update_model_params(model) + + assert len(offload_ctx.param_storages) > 0, "Should have tracked parameter storages" + + param_ptrs = {p.data.untyped_storage().data_ptr() for p in model.parameters()} + assert offload_ctx.param_storages == param_ptrs, "Tracked storages should match parameter storages" diff --git a/ICL/RL/trl_source/tests/test_chat_template_utils.py b/ICL/RL/trl_source/tests/test_chat_template_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a08976355ff56d6bb6b4163ba3c5fa8b1223deaa --- /dev/null +++ b/ICL/RL/trl_source/tests/test_chat_template_utils.py @@ -0,0 +1,283 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +import pytest +import transformers +from packaging.version import Version +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer + +from trl import clone_chat_template +from trl.chat_template_utils import ( + add_response_schema, + get_training_chat_template, + is_chat_template_prefix_preserving, + parse_response, +) + +from .testing_utils import TrlTestCase, require_jmespath + + +class TestCloneChatTemplate(TrlTestCase): + def test_clone(self): + # This tokenizer doesn't have a chat_template by default + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + # This one has a chat_template by default + source = "trl-internal-testing/tiny-Qwen3ForCausalLM" + _, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source) + + # Check if special tokens are correctly set + assert modified_tokenizer.eos_token == "<|im_end|>" + + def test_clone_with_resize(self): + # This tokenizer doesn't have a chat_template by default + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + # This one has a chat_template by default + source = "trl-internal-testing/tiny-Qwen3ForCausalLM" + modified_model, modified_tokenizer, _ = clone_chat_template( + model, tokenizer, source, resize_to_multiple_of=123 + ) + + # Check that the input embeddings have been resized to a multiple of 123 + assert (modified_model.vocab_size % 123) == 0 + # Check that the input embeddings size matches the tokenizer vocabulary size + assert model.vocab_size == len(modified_tokenizer.vocab) + + def test_clone_with_resize_and_extra_tokens_already_in_vocab(self): + # This tokenizer doesn't have a chat_template by default + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") + # This one has a chat_template by default + source = "trl-internal-testing/tiny-Qwen3ForCausalLM" + # This will add , , ... 
+
+    def test_clone_with_resize_and_extra_tokens_already_in_vocab(self):
+        # This tokenizer doesn't have a chat_template by default
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        # This one has a chat_template by default
+        source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
+        # This will add placeholder tokens (e.g. <extra_id_0>, <extra_id_1>, ...) to the tokenizer
+        modified_model, modified_tokenizer, _ = clone_chat_template(
+            model, tokenizer, source, resize_to_multiple_of=123
+        )
+        # Check that we can resize a tokenizer that already has these extra tokens
+        modified_model, modified_tokenizer, _ = clone_chat_template(
+            modified_model, modified_tokenizer, source, resize_to_multiple_of=124
+        )
+
+        # Check that the input embeddings have been resized to a multiple of 124
+        assert (modified_model.vocab_size % 124) == 0
+        # Check that the input embeddings size matches the tokenizer vocabulary size
+        assert model.vocab_size == len(modified_tokenizer.vocab)
+
+    def test_apply_new_chat_template(self):
+        # This tokenizer doesn't have a chat_template by default
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        # This one has a chat_template by default
+        source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
+        _, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source)
+        messages = [
+            {"role": "system", "content": "You are helpful"},
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi, how can I help you?"},
+        ]
+        prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)
+
+        assert (
+            prompt
+            == "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nHi, how can I help you?<|im_end|>\n"
+        )
+
+    def test_clone_with_sequence_classification_model(self):
+        # This tokenizer doesn't have a chat_template by default
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptNeoXForSequenceClassification")
+        model = AutoModelForSequenceClassification.from_pretrained(
+            "trl-internal-testing/tiny-GptNeoXForSequenceClassification"
+        )
+        # This one has a chat_template by default
+        source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
+        _, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source)
+
+        # Check if special tokens are correctly set
+        assert modified_tokenizer.eos_token == "<|im_end|>"
+
+
+@require_jmespath
+class TestAddResponseSchema:
+    @pytest.mark.xfail(
+        condition=Version(transformers.__version__) < Version("5.0.0"),
+        reason="Response parsing is not supported in transformers versions below 5.0.0",
+        strict=True,
+    )
+    def test_add_response_schema(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
+        tokenizer = add_response_schema(tokenizer)
+        assistant_text = '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "b": 4}}\n</tool_call><|im_end|>'
+        parsed = tokenizer.parse_response(assistant_text)
+        expected = {
+            "role": "assistant",
+            "content": "",
+            "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}],
+        }
+        assert parsed == expected
+
+
+class TestIsChatTemplatePrefixPreserving:
+    def test_prefix_preserving_template(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
+        tokenizer.chat_template = textwrap.dedent(r"""
+            {%- for message in messages %}
+
+            {%- if message.role == 'user' %}
+            {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }}
+            {%- elif message.role == 'assistant' %}
+            {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }}
+            {%- endif %}
+
+            {%- endfor %}
+
+            {%- if add_generation_prompt %}
+            {{- '<|im_start|>assistant\n' }}
+            {%- endif %}""")
+        assert is_chat_template_prefix_preserving(tokenizer) is True
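+
+    # For intuition: a template is prefix-preserving when rendering the first k
+    # messages always yields a string prefix of rendering the full conversation,
+    # roughly (illustrative check, not the library implementation):
+    #
+    #     full = tokenizer.apply_chat_template(messages, tokenize=False)
+    #     partial = tokenizer.apply_chat_template(messages[:-1], tokenize=False)
+    #     assert full.startswith(partial)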
+
+    def test_non_prefix_preserving_template(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
+        # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is
+        # only present for the last assistant message, which makes it non-prefix-preserving.
+        # docstyle-ignore
+        tokenizer.chat_template = textwrap.dedent(r"""
+            {%- if messages[0].role == 'system' %}
+            {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
+            {%- endif %}
+            {%- set ns = namespace(last_query_index=messages|length - 1) %}
+            {%- for message in messages[::-1] %}
+            {%- set index = (messages|length - 1) - loop.index0 %}
+            {%- if message.role == "user" and message.content is string %}
+            {%- set ns.last_query_index = index %}
+            {%- break %}
+            {%- endif %}
+            {%- endfor %}
+            {%- for message in messages %}
+            {%- set content = message.content if message.content is string else '' %}
+            {%- if message.role == "user" or (message.role == "system" and not loop.first) %}
+            {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>\n' }}
+            {%- elif message.role == "assistant" %}
+            {%- set reasoning_content = '' %}
+            {%- if message.reasoning_content is string %}
+            {%- set reasoning_content = message.reasoning_content %}
+            {%- else %}
+            {%- if '</think>' in content %}
+            {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+            {%- set content = content.split('</think>')[-1].lstrip('\n') %}
+            {%- endif %}
+            {%- endif %}
+            {%- if loop.index0 > ns.last_query_index %}
+            {%- if loop.last or (not loop.last and reasoning_content) %}
+            {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
+            {%- else %}
+            {{- '<|im_start|>' + message.role + '\n' + content }}
+            {%- endif %}
+            {%- else %}
+            {{- '<|im_start|>' + message.role + '\n' + content }}
+            {%- endif %}
+            {{- '<|im_end|>\n' }}
+            {%- endif %}
+            {%- endfor %}
+            {%- if add_generation_prompt %}
+            {{- '<|im_start|>assistant\n' }}
+            {%- if enable_thinking is defined and enable_thinking is false %}
+            {{- '<think>\n\n</think>\n\n' }}
+            {%- endif %}
+            {%- endif %}""")
+        assert is_chat_template_prefix_preserving(tokenizer) is False
+
+
+class TestGetTrainingChatTemplate:
+    def test_qwen3(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
+        assert is_chat_template_prefix_preserving(tokenizer) is False
+        tokenizer.chat_template = get_training_chat_template(tokenizer)
+        assert is_chat_template_prefix_preserving(tokenizer) is True
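+
+
+# Context for the helpers above (a summary of the property these tests encode,
+# not of documented library internals): trainers typically tokenize a prompt
+# and its completion separately and mask the prompt tokens, which is only sound
+# when rendering a prefix of the messages yields a string prefix of the full
+# rendering. get_training_chat_template is exercised as the helper that turns a
+# non-prefix-preserving template, such as one that re-inserts <think> blocks
+# only for the last assistant turn, into a prefix-preserving variant.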
+
+
+@pytest.mark.xfail(
+    condition=Version(transformers.__version__) < Version("5.0.0"),
+    reason="Tool parsing is not supported in transformers versions below 5.0.0",
+    strict=True,
+)
+@require_jmespath
+class TestParseResponse:
+    def test_parse_response(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
+        tokenizer = add_response_schema(tokenizer)
+        # docstyle-ignore
+        text = textwrap.dedent("""\
+            <tool_call>
+            {"name": "multiply", "arguments": {"a": 3, "b": 4}}
+            </tool_call><|im_end|>""")
+        assistant_text = tokenizer(text)["input_ids"]
+        parsed = parse_response(tokenizer, assistant_text)
+        expected = {
+            "role": "assistant",
+            "content": "",
+            "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}],
+        }
+        assert parsed == expected
+
+    def test_parse_response_multiple_tool_calls(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
+        tokenizer = add_response_schema(tokenizer)
+        # docstyle-ignore
+        text = textwrap.dedent("""\
+            <tool_call>
+            {"name": "multiply", "arguments": {"a": 3, "b": 4}}
+            </tool_call>
+            <tool_call>
+            {"name": "addition", "arguments": {"a": 3, "b": 4}}
+            </tool_call><|im_end|>""")
+        assistant_text = tokenizer(text)["input_ids"]
+        parsed = parse_response(tokenizer, assistant_text)
+        expected = {
+            "role": "assistant",
+            "content": "",
+            "tool_calls": [
+                {"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}},
+                {"type": "function", "function": {"name": "addition", "arguments": {"a": 3, "b": 4}}},
+            ],
+        }
+        assert parsed == expected
+
+    def test_parse_response_no_tool_call(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
+        tokenizer = add_response_schema(tokenizer)
+        text = "Here is the answer to your question.<|im_end|>"
+        assistant_text = tokenizer(text)["input_ids"]
+        parsed = parse_response(tokenizer, assistant_text)
+        expected = {
+            "role": "assistant",
+            "content": "Here is the answer to your question.",
+        }
+
+        assert parsed == expected
+
+    def test_parse_response_malformed_tool_call(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification")
+        tokenizer = add_response_schema(tokenizer)
+        # The closing brace of "arguments" is intentionally missing, so the text should be kept as plain content
+        text = '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n</tool_call><|im_end|>'
+        assistant_text = tokenizer(text)["input_ids"]
+        parsed = parse_response(tokenizer, assistant_text)
+        expected = {
+            "role": "assistant",
+            "content": '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n</tool_call>',
+        }
+
+        assert parsed == expected
diff --git a/ICL/RL/trl_source/tests/test_cli_utils.py b/ICL/RL/trl_source/tests/test_cli_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cabb8f38bd2234ca3e698e01e0664b83a51f0813
--- /dev/null
+++ b/ICL/RL/trl_source/tests/test_cli_utils.py
@@ -0,0 +1,426 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
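+
+# A quick orientation for the tests below (a summary of behavior the tests
+# themselves assert, not external documentation): TrlParser reads defaults from
+# a YAML file passed via `--config`, lets explicit CLI flags override those
+# defaults, and exports an optional `env:` mapping to os.environ. A minimal
+# config for the MyDataclass used below might look like:
+#
+#     env:
+#       VAR1: value1
+#     arg1: 2
+#     arg2: config_value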
+ +import tempfile +from dataclasses import dataclass +from unittest.mock import mock_open, patch + +import pytest +from datasets import DatasetDict, load_dataset + +from trl import DatasetMixtureConfig, TrlParser, get_dataset +from trl.scripts.utils import DatasetConfig + +from .testing_utils import TrlTestCase + + +@dataclass +class MyDataclass: + arg1: int + arg2: str = "default" + + +@dataclass +class InvalidDataclass: + config: str # This should raise an error in the TrlParser + + +class TestTrlParser(TrlTestCase): + def test_init_without_config_field(self): + """Test initialization without 'config' field in the dataclasses.""" + parser = TrlParser(dataclass_types=[MyDataclass]) + assert isinstance(parser, TrlParser) + + def test_init_with_config_field(self): + """Test initialization with a 'config' field in the dataclass (should raise ValueError).""" + with pytest.raises(ValueError, match="has a field named 'config'"): + TrlParser(dataclass_types=[InvalidDataclass]) + + @patch("builtins.open", mock_open(read_data="env:\n VAR1: value1\n VAR2: value2\narg1: 2")) + @patch("yaml.safe_load") + @patch("os.environ", new_callable=dict) # Mock os.environ as a dictionary + def test_parse_args_and_config_with_valid_config(self, mock_environ, mock_yaml_load): + """Test parse_args_and_config method with valid arguments and config.""" + mock_yaml_load.return_value = {"env": {"VAR1": "value1", "VAR2": "value2"}, "arg1": 2} + + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg2", "value", "--config", "config.yaml"] # don't set arg1 to test default value + + # Simulate the config being loaded and environment variables being set + result_args = parser.parse_args_and_config(args) + + # Set the environment variables using the mock + mock_environ["VAR1"] = "value1" + mock_environ["VAR2"] = "value2" + + # Ensure that the environment variables were set correctly + assert mock_environ.get("VAR1") == "value1" + assert mock_environ.get("VAR2") == "value2" + + # Check the parsed arguments + assert len(result_args) == 1 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 2 + assert result_args[0].arg2 == "value" + + @patch("builtins.open", mock_open(read_data="arg1: 2")) + @patch("yaml.safe_load") + def test_parse_args_and_arg_override_config(self, mock_yaml_load): + """Test parse_args_and_config method and check that arguments override the config.""" + mock_yaml_load.return_value = {"arg1": 2} # this arg is meant to be overridden + + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "3", "--config", "config.yaml"] # override arg1 default with 3 + + # Simulate the config being loaded and arguments being passed + result_args = parser.parse_args_and_config(args) + + # Check the parsed arguments + assert len(result_args) == 1 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 3 + + @patch("builtins.open", mock_open(read_data="env: not_a_dict")) + @patch("yaml.safe_load") + def test_parse_args_and_config_with_invalid_env(self, mock_yaml_load): + """Test parse_args_and_config method when the 'env' field is not a dictionary.""" + mock_yaml_load.return_value = {"env": "not_a_dict"} + + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "2", "--arg2", "value", "--config", "config.yaml"] + + with pytest.raises(ValueError, match="`env` field should be a dict in the YAML file."): + parser.parse_args_and_config(args) + + def test_parse_args_and_config_without_config(self): + """Test parse_args_and_config 
without the `--config` argument.""" + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "2", "--arg2", "value"] + + # Simulate no config, just parse args normally + result_args = parser.parse_args_and_config(args) + + # Check that the arguments are parsed as is + assert len(result_args) == 1 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 2 + assert result_args[0].arg2 == "value" + + def test_set_defaults_with_config(self): + """Test set_defaults_with_config updates the defaults.""" + parser = TrlParser(dataclass_types=[MyDataclass]) + + # Update defaults + parser.set_defaults_with_config(arg1=42) + + # Ensure the default value is updated + result_args = parser.parse_args_and_config([]) + assert len(result_args) == 1 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 42 + + def test_parse_args_and_config_with_remaining_strings(self): + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "2", "--arg2", "value", "remaining"] + + # Simulate no config, just parse args normally + result_args = parser.parse_args_and_config(args, return_remaining_strings=True) + + # Check that the arguments are parsed as is + assert len(result_args) == 2 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 2 + assert result_args[0].arg2 == "value" + assert result_args[1] == ["remaining"] + + @patch("builtins.open", mock_open(read_data="remaining_string_in_config: abc")) + @patch("yaml.safe_load") + def test_parse_args_and_config_with_remaining_strings_in_config_and_args(self, mock_yaml_load): + mock_yaml_load.return_value = {"remaining_string_in_config": "abc"} + + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "2", "--remaining_string_in_args", "def", "--config", "config.yaml"] + + # Simulate the config being loaded and arguments being passed + result_args = parser.parse_args_and_config(args, return_remaining_strings=True) + + # Check that the arguments are parsed as is + assert len(result_args) == 2 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 2 + assert result_args[1] == ["--remaining_string_in_config", "abc", "--remaining_string_in_args", "def"] + + @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value")) + @patch("yaml.safe_load") + def test_subparsers_with_config_defaults(self, mock_yaml_load): + """Test that config defaults are applied to all subparsers.""" + mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"} + + # Create the main parser + parser = TrlParser() + + # Add subparsers + subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser) + + # Create a subparser for a specific command + subparsers.add_parser("subcommand", dataclass_types=[MyDataclass]) + + # Parse with config file + args = ["subcommand", "--config", "config.yaml"] + result_args = parser.parse_args_and_config(args) + + # Check main parser arguments + assert len(result_args) == 1 + + # Check that config values were applied to the subparser + assert result_args[0].arg1 == 2 # Default from config + assert result_args[0].arg2 == "config_value" # Default from config + + @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value")) + @patch("yaml.safe_load") + def test_subparsers_with_config_defaults_and_arg_override(self, mock_yaml_load): + """Test that config defaults are applied to all subparsers.""" + mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"} + + # Create the 
main parser
+        parser = TrlParser()
+
+        # Add subparsers
+        subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser)
+
+        # Create a subparser for a specific command
+        subparsers.add_parser("subcommand", dataclass_types=[MyDataclass])
+
+        # Test with command line arguments overriding config
+        args = ["subcommand", "--arg1", "3", "--config", "config.yaml"]
+        result_args = parser.parse_args_and_config(args)
+
+        # Command line arguments should override config
+        assert result_args[0].arg1 == 3
+        assert result_args[0].arg2 == "config_value"  # Still from config
+
+    @patch("builtins.open", mock_open(read_data="arg1: 2\nthis_arg_does_not_exist: config_value"))
+    @patch("yaml.safe_load")
+    def test_subparsers_with_config_defaults_and_arg_override_wrong_name(self, mock_yaml_load):
+        """Test that an unknown key in the config raises, unless fail_with_unknown_args=False."""
+        mock_yaml_load.return_value = {"arg1": 2, "this_arg_does_not_exist": "config_value"}
+
+        # Create the main parser
+        parser = TrlParser()
+
+        # Add subparsers
+        subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser)
+
+        # Create a subparser for a specific command
+        subparsers.add_parser("subcommand", dataclass_types=[MyDataclass])
+
+        # Test with command line arguments overriding config
+        args = ["subcommand", "--arg1", "3", "--config", "config.yaml"]
+        with pytest.raises(ValueError):
+            parser.parse_args_and_config(args)
+
+        parser.parse_args_and_config(args, fail_with_unknown_args=False)
+
+    @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value"))
+    @patch("yaml.safe_load")
+    def test_subparsers_multiple_with_config_defaults(self, mock_yaml_load):
+        """Test that config defaults are applied to all subparsers."""
+        mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"}
+
+        # Create the main parser
+        parser = TrlParser()
+
+        # Add subparsers
+        subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser)
+
+        # Create a subparser for a specific command
+        subparsers.add_parser("subcommand0", dataclass_types=[MyDataclass])
+        subparsers.add_parser("subcommand1", dataclass_types=[MyDataclass])
+
+        for idx in range(2):
+            # Parse with config file
+            args = [f"subcommand{idx}", "--config", "config.yaml"]
+            result_args = parser.parse_args_and_config(args)
+
+            # Check main parser arguments
+            assert len(result_args) == 1
+
+            # Check that config values were applied to the subparser
+            assert result_args[0].arg1 == 2  # Default from config
+            assert result_args[0].arg2 == "config_value"  # Default from config
+
+
+class TestGetDataset:
+    def test_single_dataset_with_config(self):
+        mixture_config = DatasetMixtureConfig(
+            datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")]
+        )
+        result = get_dataset(mixture_config)
+        expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
+        assert expected["train"][:] == result["train"][:]
+
+    def test_single_dataset_preference_config(self):
+        mixture_config = DatasetMixtureConfig(
+            datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_preference")]
+        )
+        result = get_dataset(mixture_config)
+        expected = load_dataset("trl-internal-testing/zen", "standard_preference")
+        assert expected["train"][:] == result["train"][:]
+
+    def test_single_dataset_streaming(self):
+        mixture_config = DatasetMixtureConfig(
+            datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")],
+            streaming=True,
+        )
+        result = get_dataset(mixture_config)
+        expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
+        assert expected["train"].to_list() == list(result["train"])
+
+    def test_dataset_mixture_basic(self):
+        dataset_config1 = DatasetConfig(
+            path="trl-internal-testing/zen", name="standard_prompt_completion", split="train", columns=["prompt"]
+        )
+        dataset_config2 = DatasetConfig(
+            path="trl-internal-testing/zen", name="standard_preference", split="train", columns=["prompt"]
+        )
+        mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2])
+        result = get_dataset(mixture_config)
+        assert isinstance(result, DatasetDict)
+        assert "train" in result
+        train_dataset = result["train"]
+        assert train_dataset.column_names == ["prompt"]
+        prompts = train_dataset["prompt"]
+        expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
+        assert prompts[: len(prompts) // 2] == expected_first_half["prompt"]
+        expected_second_half = load_dataset("trl-internal-testing/zen", "standard_prompt_completion", split="train")
+        assert prompts[len(prompts) // 2 :] == expected_second_half["prompt"]
+
+    def test_dataset_mixture_with_weights(self):
+        dataset_config1 = DatasetConfig(
+            path="trl-internal-testing/zen", name="standard_prompt_completion", split="train[:50%]", columns=["prompt"]
+        )
+        dataset_config2 = DatasetConfig(
+            path="trl-internal-testing/zen", name="standard_preference", split="train[:50%]", columns=["prompt"]
+        )
+        mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2])
+        result = get_dataset(mixture_config)
+        assert isinstance(result, DatasetDict)
+        assert "train" in result
+        train_dataset = result["train"]
+        assert train_dataset.column_names == ["prompt"]
+        prompts = train_dataset["prompt"]
+        expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train[:50%]")
+        assert prompts[: len(prompts) // 2] == expected_first_half["prompt"]
+        expected_second_half = load_dataset(
+            "trl-internal-testing/zen", "standard_prompt_completion", split="train[:50%]"
+        )
+        assert prompts[len(prompts) // 2 :] == expected_second_half["prompt"]
+
+    def test_dataset_mixture_with_test_split(self):
+        mixture_config = DatasetMixtureConfig(
+            datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")],
+            test_split_size=2,
+        )
+        result = get_dataset(mixture_config)
+        assert isinstance(result, DatasetDict)
+        assert "train" in result
+        assert "test" in result
+        assert len(result["train"]) == 15
+        assert len(result["test"]) == 2
+
+    def test_empty_dataset_mixture_raises_error(self):
+        mixture_config = DatasetMixtureConfig(datasets=[])
+
+        with pytest.raises(ValueError, match="No datasets were loaded"):
+            get_dataset(mixture_config)
+
+    def test_mixture_multiple_different_configs(self):
+        dataset_config1 = DatasetConfig(
+            path="trl-internal-testing/zen", name="conversational_preference", split="train", columns=["prompt"]
+        )
+        dataset_config2 = DatasetConfig(
+            path="trl-internal-testing/zen", name="conversational_prompt_only", split="test"
+        )
+        mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2])
+        result = get_dataset(mixture_config)
+        assert isinstance(result, DatasetDict)
+        assert "train" in result
+        assert len(result["train"]) > 0
+
+    def test_trlparser_parses_yaml_config_correctly(self):
+        # Prepare YAML content exactly like the documented example
+        # docstyle-ignore
+        yaml_content = """
+        datasets:
+          - path: trl-internal-testing/zen
+            name: standard_prompt_only
+          - path: trl-internal-testing/zen
+            name: standard_preference
+            columns:
+              - prompt
+        """
+
+        # Write YAML to a temporary file
+        with tempfile.NamedTemporaryFile("w+", suffix=".yaml") as tmpfile:
+            tmpfile.write(yaml_content)
+            tmpfile.flush()
+            parser = TrlParser((DatasetMixtureConfig,))
+            args = parser.parse_args_and_config(args=["--config", tmpfile.name])[0]
+
+        # Assert that we got a DatasetMixtureConfig instance
+        assert isinstance(args, DatasetMixtureConfig)
+
+        # Assert datasets list length
+        assert len(args.datasets) == 2
+
+        # Check first dataset
+        dataset_config1 = args.datasets[0]
+        assert isinstance(dataset_config1, DatasetConfig)
+        assert dataset_config1.path == "trl-internal-testing/zen"
+        assert dataset_config1.name == "standard_prompt_only"
+        assert dataset_config1.columns is None  # No columns specified
+
+        # Check second dataset
+        dataset_config2 = args.datasets[1]
+        assert isinstance(dataset_config2, DatasetConfig)
+        assert dataset_config2.path == "trl-internal-testing/zen"
+        assert dataset_config2.name == "standard_preference"
+        assert dataset_config2.columns == ["prompt"]  # Columns specified
+
+    def test_trlparser_parses_yaml_and_loads_dataset(self):
+        # Prepare YAML content exactly like the documented example
+        # docstyle-ignore
+        yaml_content = """
+        datasets:
+          - path: trl-internal-testing/zen
+            name: standard_language_modeling
+        """
+
+        # Write YAML to a temporary file
+        with tempfile.NamedTemporaryFile("w+", suffix=".yaml") as tmpfile:
+            tmpfile.write(yaml_content)
+            tmpfile.flush()
+            parser = TrlParser((DatasetMixtureConfig,))
+            args = parser.parse_args_and_config(args=["--config", tmpfile.name])[0]
+
+        # Load the dataset using get_dataset
+        result = get_dataset(args)
+        expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
+        assert expected["train"][:] == result["train"][:]
diff --git a/ICL/RL/trl_source/tests/test_data_utils.py b/ICL/RL/trl_source/tests/test_data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..111e52b23ef2442c773e96330fce3601a12dfefa
--- /dev/null
+++ b/ICL/RL/trl_source/tests/test_data_utils.py
@@ -0,0 +1,1198 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
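+
+# Dataset shapes exercised throughout this module (summarized from the example
+# fixtures below): a sample is "conversational" when its messages are
+# role/content dicts and "standard" when its fields are plain strings, in
+# shapes such as:
+#
+#     {"messages": [...]}                                # language modeling
+#     {"prompt": ...}                                    # prompt-only
+#     {"prompt": ..., "completion": ...}                 # prompt-completion
+#     {"prompt": ..., "chosen": ..., "rejected": ...}    # preference
+#     {"prompt": ..., "completion": ..., "label": bool}  # unpaired preference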
+ +import copy +import textwrap +from time import strftime + +import pytest +from datasets import Dataset, DatasetDict +from transformers import AutoProcessor, AutoTokenizer, is_vision_available + +from trl.data_utils import ( + apply_chat_template, + extract_prompt, + is_conversational, + is_conversational_from_value, + maybe_apply_chat_template, + maybe_convert_to_chatml, + maybe_extract_prompt, + maybe_unpair_preference_dataset, + pack_dataset, + prepare_multimodal_messages, + prepare_multimodal_messages_vllm, + truncate_dataset, + unpair_preference_dataset, +) + +from .testing_utils import TrlTestCase, require_vision + + +if is_vision_available(): + from PIL import Image + + +@require_vision +class TestPrepareMultimodalMessages: + def test_basic_user_assistant_conversation(self): + """Test basic conversation with user and assistant messages.""" + messages = [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ] + image = Image.new("RGB", (10, 10), color="blue") + messages = prepare_multimodal_messages(messages, images=[image]) + + expected = [ + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, + ] + + assert messages == expected + + def test_first_user_message_gets_image(self): + """Test that only the first user message gets an image.""" + messages = [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + {"role": "user", "content": "How about the grass?"}, + ] + + image = Image.new("RGB", (10, 10), color="blue") + messages = prepare_multimodal_messages(messages, images=[image]) + + expected = [ + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, + { + "role": "user", + "content": [{"type": "text", "text": "How about the grass?"}], + }, + ] + + assert messages == expected + + def test_multiple_images(self): + """Test that multiple images are added to the first user message.""" + messages = [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ] + images = [Image.new("RGB", (10, 10), color=color) for color in ["red", "green", "blue"]] + messages = prepare_multimodal_messages(messages, images=images) + + expected = [ + { + "role": "user", + "content": [ + {"type": "image", "image": images[0]}, + {"type": "image", "image": images[1]}, + {"type": "image", "image": images[2]}, + {"type": "text", "text": "What color is the sky?"}, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, + ] + + assert messages == expected + + def test_system_message_transformation(self): + """Test that system messages are properly transformed.""" + messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "What color is the sky?"}, + ] + + image = Image.new("RGB", (10, 10), color="blue") + messages = prepare_multimodal_messages(messages, images=[image]) + + expected = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant"}], + }, + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + ] + + assert messages == 
expected + + def test_already_prepared_messages_unchanged(self): + """Test that messages with list content are not modified.""" + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]}, + {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}, + {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, + ] + + image = Image.new("RGB", (10, 10), color="blue") + messages = prepare_multimodal_messages(messages, images=[image]) + + expected = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant"}], + }, + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, + ] + + assert messages == expected + + def test_mixed_prepared_and_unprepared_messages(self): + """Test handling of mixed prepared and unprepared messages.""" + messages = [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, + {"role": "user", "content": "What about the grass?"}, + ] + + image = Image.new("RGB", (10, 10), color="blue") + messages = prepare_multimodal_messages(messages, images=[image]) + + expected = [ + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, + { + "role": "user", + "content": [{"type": "text", "text": "What about the grass?"}], + }, + ] + + assert messages == expected + + +@require_vision +class TestPrepareMultimodalMessagesVLLM: + def test_single_image_conversion(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")}, + {"type": "text", "text": "What color is the sky?"}, + ], + } + ] + + result = prepare_multimodal_messages_vllm(messages) + + # Original should remain unchanged (deepcopy test) + assert messages[0]["content"][0]["type"] == "image" + + # Converted version should have correct structure + assert result[0]["content"][0]["type"] == "image_pil" + assert "image_pil" in result[0]["content"][0] + assert "image" not in result[0]["content"][0] + assert isinstance(result[0]["content"][0]["image_pil"], Image.Image) + assert result[0]["content"][1]["type"] == "text" + + def test_mixed_content_conversion(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is the sky?"}, + {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")}, + ], + } + ] + + result = prepare_multimodal_messages_vllm(messages) + + # The image part should be converted, text should be unchanged + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][1]["type"] == "image_pil" + + def test_no_images(self): + messages = [{"role": "user", "content": [{"type": "text", "text": "What color is the sky?"}]}] + + result = prepare_multimodal_messages_vllm(messages) + + # Should be identical since there are no images + assert result == messages + # And a deepcopy — not the same object + assert result is not messages + assert result[0] is not messages[0] + + def test_multiple_messages(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is the sky?"}, + {"type": "image", "image": Image.new("RGB", (10, 
10), color="blue")}, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, + ] + + result = prepare_multimodal_messages_vllm(messages) + + assert result[0]["content"][1]["type"] == "image_pil" + assert result[1]["content"][0]["type"] == "text" + assert result[1]["content"][0]["text"] == "It is blue." + + def test_deepcopy_integrity(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is the sky?"}, + {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")}, + ], + }, + ] + original = copy.deepcopy(messages) + + _ = prepare_multimodal_messages_vllm(messages) + + # Original should not be mutated + assert messages == original + + +class TestIsConversational(TrlTestCase): + # fmt: off + conversational_examples = [ + { # Language modeling + "messages": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ], + }, + { # Prompt-only + "prompt": [{"role": "user", "content": "What color is the sky?"}], + }, + { # Prompt-completion + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}], + }, + { # Preference + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "chosen": [{"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "assistant", "content": "It is green."}], + }, + { # Preference with implicit prompt + "chosen": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ], + "rejected": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is green."}, + ], + }, + { # Preference with tool calls + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "chosen": [ + {"role": "assistant", "tool_calls": [{"type": "function", "function": {"name": "get_color", "arguments": {"what": "sky"}}}]}, + {"role": "tool", "name": "get_color", "content": "blue"}, + {"role": "assistant", "content": "It is blue."}, + ], + "rejected": [ + {"role": "assistant", "tool_calls": [{"type": "function", "function": {"name": "get_color", "arguments": {"what": "tree"}}}]}, + {"role": "tool", "name": "get_color", "content": "green"}, + {"role": "assistant", "content": "It is green."}, + ], + "tools": [ + { + "type": "function", + "function": { + "description": "Gets the color.", + "name": "get_color", + "parameters": {"properties": {"what": {"description": "What to get the color of.", "type": "string"}}, "required": ["what"], "type": "object"}, + "return": {"description": "The color.", "type": "string"}, + }, + }, + ], + }, + { # Unpaired preference + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}], + "label": True, + }, + { # Language modeling with harmony + "messages": [ + {"role": "system", "content": "Respond in a friendly manner."}, + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."}, + ], + }, + { # Prompt-only with harmony + "prompt": [ + {"role": "system", "content": "Respond in a friendly manner."}, + {"role": "user", "content": "What color is the sky?"}, + ], + }, + { # Prompt-completion with harmony + "prompt": [ + {"role": "system", "content": "Respond in a friendly manner."}, + {"role": "user", "content": "What color is the sky?"}, + ], + 
"completion": [ + {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."}, + ], + }, + { # Preference with harmony + "prompt": [ + {"role": "system", "content": "Respond in a friendly manner."}, + {"role": "user", "content": "What color is the sky?"}, + ], + "chosen": [ + {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."}, + ], + "rejected": [ + {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."}, + ], + }, + { # Preference with implicit prompt and harmony + "chosen": [ + {"role": "system", "content": "Respond in a friendly manner."}, + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."}, + ], + "rejected": [ + {"role": "system", "content": "Respond in a friendly manner."}, + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."}, + ], + }, + { # Unpaired preference with harmony + "prompt": [ + {"role": "system", "content": "Respond in a friendly manner."}, + {"role": "user", "content": "What color is the sky?"}, + ], + "completion": [ + {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."}, + ], + "label": True, + }, + ] + # fmt: on + + non_conversational_examples = [ + {"prompt": "The sky is", "completion": " blue."}, + {"text": "The sky is blue."}, + {"prompt": "The sky is"}, + {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}, + {"prompt": "The sky is", "completion": " blue.", "label": True}, + ] + + @pytest.mark.parametrize("example", conversational_examples) + def test_conversational(self, example): + assert is_conversational(example) + + @pytest.mark.parametrize("example", non_conversational_examples) + def test_non_conversational(self, example): + assert not is_conversational(example) + + +class TestIsConversationalFromValue(TrlTestCase): + def test_positive_1(self): + example = { + "conversations": [ + {"from": "user", "value": "What color is the sky?"}, + {"from": "assistant", "value": "It is blue."}, + ], + } + assert is_conversational_from_value(example) + + def test_negative_1(self): + example = { + "messages": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ], + } + assert not is_conversational_from_value(example) + + def test_negative_2(self): + example = {"text": "The sky is blue."} + assert not is_conversational_from_value(example) + + +class TestApplyChatTemplate(TrlTestCase): + tokenizers = [ + "trl-internal-testing/tiny-CohereForCausalLM", + "trl-internal-testing/tiny-DeepseekV3ForCausalLM", + "trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528", + "trl-internal-testing/tiny-FalconMambaForCausalLM", + "trl-internal-testing/tiny-Gemma2ForCausalLM", + "trl-internal-testing/tiny-GemmaForCausalLM", + "trl-internal-testing/tiny-GptOssForCausalLM", + "trl-internal-testing/tiny-LlamaForCausalLM-3.1", + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-LlamaForCausalLM-3", + "trl-internal-testing/tiny-MistralForCausalLM-0.1", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", + "trl-internal-testing/tiny-Phi3ForCausalLM", + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "trl-internal-testing/tiny-Qwen3ForCausalLM", + ] + + conversational_examples = [ + { # 
Language modeling + "messages": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ], + }, + { # Prompt-only + "prompt": [{"role": "user", "content": "What color is the sky?"}], + }, + { # Prompt-completion + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}], + }, + { # Preference + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "chosen": [{"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "assistant", "content": "It is green."}], + }, + { # Preference with implicit prompt + "chosen": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ], + "rejected": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is green."}, + ], + }, + { # Unpaired preference + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}], + "label": True, + }, + ] + + non_conversational_examples = [ + {"text": "The sky is blue."}, # Language modeling + {"prompt": "The sky is"}, # Prompt-only + {"prompt": "The sky is", "completion": " blue."}, # Prompt-completion + {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}, # Preference + {"chosen": "The sky is blue.", "rejected": "The sky is green."}, # Preference with implicit prompt + {"prompt": "The sky is", "completion": " blue.", "label": True}, # Unpaired preference + ] + + @pytest.mark.parametrize("example", conversational_examples) + @pytest.mark.parametrize("tokenizer_id", tokenizers) + def test_apply_chat_template(self, tokenizer_id, example): + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + result = apply_chat_template(example, tokenizer) + + # Checking if the result is a dictionary + assert isinstance(result, dict) + + # The chat template should be applied to the following keys + for key in ["prompt", "chosen", "rejected", "completion"]: + if key in example: + assert key in result + assert isinstance(result[key], str) + + # Exception for messages, the key is "text" once the chat template is applied + if "messages" in example: + assert "text" in result + assert isinstance(result["text"], str) + + # The label should be kept + if "label" in example: + assert "label" in result + assert isinstance(result["label"], bool) + assert result["label"] == example["label"] + + # both conversational and non-conversational examples + @pytest.mark.parametrize("example", conversational_examples + non_conversational_examples) + @pytest.mark.parametrize("tokenizer_id", tokenizers) + def test_maybe_apply_chat_template(self, tokenizer_id, example): + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + result = maybe_apply_chat_template(example, tokenizer) + + # Checking if the result is a dictionary + assert isinstance(result, dict) + + # The chat template should be applied to the following keys + for key in ["prompt", "chosen", "rejected", "completion"]: + if key in example: + assert key in result + assert isinstance(result[key], str) + + # Exception for messages, the key is "text" once the chat template is applied + if "messages" in example: + assert "text" in result + assert isinstance(result["text"], str) + + # The label should be kept + if "label" in example: + assert "label" in result + assert isinstance(result["label"], bool) + assert result["label"] == example["label"] + + def 
test_apply_chat_template_with_chat_template_kwargs(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")
+
+        example = {
+            "prompt": [{"role": "user", "content": "What color is the sky?"}],
+            # with this tokenizer, when you pass enable_thinking=False, it will add "<think>\n\n</think>\n\n"
+            "chat_template_kwargs": {"enable_thinking": False},
+        }
+        result = apply_chat_template(example, tokenizer)
+
+        # docstyle-ignore
+        expected = textwrap.dedent("""\
+            <|im_start|>user
+            What color is the sky?<|im_end|>
+            <|im_start|>assistant
+            <think>
+
+            </think>
+
+            """)
+
+        assert result["prompt"] == expected
+
+    def test_apply_chat_template_with_tools(self):
+        tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")
+
+        # Define dummy test tools
+        def get_current_temperature(location: str):
+            """
+            Gets the temperature at a given location.
+
+            Args:
+                location: The location to get the temperature for
+            """
+            return 22.0
+
+        # Define test case
+        test_case = {
+            "prompt": [
+                {"content": "What's the temperature in London?", "role": "user"},
+            ]
+        }
+        # Test with tools
+        result_with_tools = apply_chat_template(test_case, tokenizer, tools=[get_current_temperature])
+
+        # Verify tools are included in the output
+        assert "get_current_temperature" in result_with_tools["prompt"]
+
+        # Test without tools
+        result_without_tools = apply_chat_template(test_case, tokenizer, tools=None)
+
+        # Verify tools are not included in the output
+        assert "get_current_temperature" not in result_without_tools["prompt"]
+
+
+class TestApplyChatTemplateHarmony(TrlTestCase):
+    def test_language_modeling(self):
+        messages = {
+            "messages": [
+                {"role": "system", "content": "Respond in a friendly manner."},
+                {"role": "user", "content": "What color is the sky?"},
+                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
+            ],
+        }
+        output = apply_chat_template(
+            messages,
+            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
+            reasoning_effort="low",
+            model_identity="You are HuggingGPT.",
+        )
+
+        # docstyle-ignore
+        expected = textwrap.dedent(f"""\
+            <|start|>system<|message|>You are HuggingGPT.
+            Knowledge cutoff: 2024-06
+            Current date: {strftime("%Y-%m-%d")}
+
+            Reasoning: low
+
+            # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions
+
+            Respond in a friendly manner.
+
+            <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>""")
+
+        assert output["text"] == expected
+
+    def test_prompt_only(self):
+        messages = {
+            "prompt": [
+                {"role": "system", "content": "Respond in a friendly manner."},
+                {"role": "user", "content": "What color is the sky?"},
+            ],
+        }
+        output = apply_chat_template(
+            messages,
+            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
+            reasoning_effort="low",
+            model_identity="You are HuggingGPT.",
+        )
+
+        # docstyle-ignore
+        expected = textwrap.dedent(f"""\
+            <|start|>system<|message|>You are HuggingGPT.
+            Knowledge cutoff: 2024-06
+            Current date: {strftime("%Y-%m-%d")}
+
+            Reasoning: low
+
+            # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions
+
+            Respond in a friendly manner.
+ + <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""") + + assert output["prompt"] == expected + + def test_prompt_completion(self): + messages = { + "prompt": [ + {"role": "system", "content": "Respond in a friendly manner."}, + {"role": "user", "content": "What color is the sky?"}, + ], + "completion": [ + {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."}, + ], + } + output = apply_chat_template( + messages, + tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"), + reasoning_effort="low", + model_identity="You are HuggingGPT.", + ) + + # docstyle-ignore + expected_prompt = textwrap.dedent(f"""\ + <|start|>system<|message|>You are HuggingGPT. + Knowledge cutoff: 2024-06 + Current date: {strftime("%Y-%m-%d")} + + Reasoning: low + + # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions + + Respond in a friendly manner. + + <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""") + expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>" + + assert output["prompt"] == expected_prompt + assert output["completion"] == expected_completion + + def test_preference(self): + messages = { + "prompt": [ + {"role": "system", "content": "Respond in a friendly manner."}, + {"role": "user", "content": "What color is the sky?"}, + ], + "chosen": [ + {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."}, + ], + "rejected": [ + {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."}, + ], + } + output = apply_chat_template( + messages, + tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"), + reasoning_effort="low", + model_identity="You are HuggingGPT.", + ) + + # docstyle-ignore + expected_prompt = textwrap.dedent(f"""\ + <|start|>system<|message|>You are HuggingGPT. + Knowledge cutoff: 2024-06 + Current date: {strftime("%Y-%m-%d")} + + Reasoning: low + + # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions + + Respond in a friendly manner. 
+ + <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""") + expected_chosen = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>" + expected_rejected = "<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>" + + assert output["prompt"] == expected_prompt + assert output["chosen"] == expected_chosen + assert output["rejected"] == expected_rejected + + def test_preference_with_implicit_prompt(self): + messages = { + "chosen": [ + {"role": "system", "content": "Respond in a friendly manner."}, + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."}, + ], + "rejected": [ + {"role": "system", "content": "Respond in a friendly manner."}, + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."}, + ], + } + output = apply_chat_template( + messages, + tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"), + reasoning_effort="low", + model_identity="You are HuggingGPT.", + ) + + # docstyle-ignore + expected_chosen = textwrap.dedent(f"""\ + <|start|>system<|message|>You are HuggingGPT. + Knowledge cutoff: 2024-06 + Current date: {strftime("%Y-%m-%d")} + + Reasoning: low + + # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions + + Respond in a friendly manner. + + <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>""") + + # docstyle-ignore + expected_rejected = textwrap.dedent(f"""\ + <|start|>system<|message|>You are HuggingGPT. + Knowledge cutoff: 2024-06 + Current date: {strftime("%Y-%m-%d")} + + Reasoning: low + + # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions + + Respond in a friendly manner. + + <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>""") + + assert output["chosen"] == expected_chosen + assert output["rejected"] == expected_rejected + + def test_unpaired_preference(self): + messages = { + "prompt": [ + {"role": "system", "content": "Respond in a friendly manner."}, + {"role": "user", "content": "What color is the sky?"}, + ], + "completion": [ + {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."}, + ], + "label": True, + } + output = apply_chat_template( + messages, + tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"), + reasoning_effort="low", + model_identity="You are HuggingGPT.", + ) + + # docstyle-ignore + expected_prompt = textwrap.dedent(f"""\ + <|start|>system<|message|>You are HuggingGPT. + Knowledge cutoff: 2024-06 + Current date: {strftime("%Y-%m-%d")} + + Reasoning: low + + # Valid channels: analysis, commentary, final. 
Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions
+
+            Respond in a friendly manner.
+
+            <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")
+        expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"
+
+        assert output["prompt"] == expected_prompt
+        assert output["completion"] == expected_completion
+        assert output["label"]
+
+
+class TestUnpairPreferenceDataset(TrlTestCase):
+    paired_dataset = Dataset.from_dict(
+        {
+            "prompt": ["The sky is", "The sun is"],
+            "chosen": [" blue.", " in the sky."],
+            "rejected": [" green.", " in the sea."],
+        }
+    )
+
+    unpaired_dataset = Dataset.from_dict(
+        {
+            "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
+            "completion": [" blue.", " in the sky.", " green.", " in the sea."],
+            "label": [True, True, False, False],
+        }
+    )
+
+    def test_unpair_preference_dataset(self):
+        # Test that a paired dataset is correctly converted to unpaired
+        unpaired_dataset = unpair_preference_dataset(self.paired_dataset)
+        assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), (
+            "The paired dataset should be converted to unpaired."
+        )
+
+    def test_unpair_preference_dataset_dict(self):
+        # Test that a paired dataset dict is correctly converted to unpaired
+        paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
+        unpaired_dataset_dict = unpair_preference_dataset(paired_dataset_dict)
+        assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), (
+            "The paired dataset should be converted to unpaired."
+        )
+
+    def test_maybe_unpair_preference_dataset(self):
+        # Test that a paired dataset is correctly converted to unpaired with maybe_unpair_preference_dataset
+        unpaired_dataset = maybe_unpair_preference_dataset(self.paired_dataset)
+        assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), (
+            "The paired dataset should be converted to unpaired."
+        )
+
+    def test_maybe_unpair_preference_dataset_dict(self):
+        # Test that a paired dataset dict is correctly converted to unpaired with maybe_unpair_preference_dataset
+        paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
+        unpaired_dataset_dict = maybe_unpair_preference_dataset(paired_dataset_dict)
+        assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), (
+            "The paired dataset should be converted to unpaired."
+        )
+
+    def test_maybe_unpair_preference_dataset_already_unpaired(self):
+        # Test that an already unpaired dataset remains unchanged with maybe_unpair_preference_dataset
+        unpaired_dataset = maybe_unpair_preference_dataset(self.unpaired_dataset)
+        assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), (
+            "The unpaired dataset should remain unchanged."
+        )
+
+    def test_maybe_unpair_preference_dataset_dict_already_unpaired(self):
+        # Test that an already unpaired dataset dict remains unchanged with maybe_unpair_preference_dataset
+        unpaired_dataset_dict = maybe_unpair_preference_dataset(DatasetDict({"abc": self.unpaired_dataset}))
+        assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), (
+            "The unpaired dataset should remain unchanged."
+ ) + + +class TestExtractPrompt(TrlTestCase): + example_implicit_prompt_conversational = { + "chosen": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ], + "rejected": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is green."}, + ], + } + + example_explicit_prompt_conversational = { + "prompt": [ + {"role": "user", "content": "What color is the sky?"}, + ], + "chosen": [ + {"role": "assistant", "content": "It is blue."}, + ], + "rejected": [ + {"role": "assistant", "content": "It is green."}, + ], + } + + example_implicit_prompt_standard = { + "chosen": "The sky is blue.", + "rejected": "The sky is green.", + } + + example_explicit_prompt_standard = { + "prompt": "The sky is", + "chosen": " blue.", + "rejected": " green.", + } + + def test_extract_prompt_conversational(self): + # Test that the prompt is correctly extracted from the dataset + example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational) + assert example_extracted_prompt == self.example_explicit_prompt_conversational, ( + "The prompt is not correctly extracted from the dataset." + ) + + def test_maybe_extract_prompt_conversational(self): + # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt + example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational) + assert example_extracted_prompt == self.example_explicit_prompt_conversational, ( + "The prompt is not correctly extracted from the dataset." + ) + + def test_maybe_extract_prompt_conversational_already_explicit(self): + # Test that the prompt remains unchanged with maybe_extract_prompt + example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational) + assert example_extracted_prompt == self.example_explicit_prompt_conversational, ( + "The prompt should remain unchanged." + ) + + def test_extract_prompt_standard(self): + # Test that the prompt is correctly extracted from the dataset + example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard) + assert example_extracted_prompt == self.example_explicit_prompt_standard, ( + "The prompt is not correctly extracted from the dataset." + ) + + def test_maybe_extract_prompt_standard(self): + # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt + example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard) + assert example_extracted_prompt == self.example_explicit_prompt_standard, ( + "The prompt is not correctly extracted from the dataset." + ) + + def test_maybe_extract_prompt_standard_already_explicit(self): + # Test that the prompt remains unchanged with maybe_extract_prompt + example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard) + assert example_extracted_prompt == self.example_explicit_prompt_standard, "The prompt should remain unchanged." 
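+
+
+# The TestExtractPrompt fixtures above pin down what extract_prompt does: it factors the prefix shared by
+# "chosen" and "rejected" out into an explicit "prompt" field. A minimal sketch of that idea for the standard
+# (string) format; illustrative only, not the TRL implementation, which also handles conversational examples
+# and the exact placement of boundary whitespace:
+def _extract_prompt_sketch(example):
+    chosen, rejected = example["chosen"], example["rejected"]
+    # Walk forward while the two strings agree; the shared prefix is the prompt.
+    i = 0
+    while i < min(len(chosen), len(rejected)) and chosen[i] == rejected[i]:
+        i += 1
+    return {"prompt": chosen[:i], "chosen": chosen[i:], "rejected": rejected[i:]}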
+
+
+class TestPackDatasetWrapped(TrlTestCase):
+    def test_with_dataset(self):
+        examples = {
+            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
+            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
+        }
+        dataset = Dataset.from_dict(examples)
+        seq_length = 3
+        expected_output = {
+            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
+            "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
+        }
+        dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
+        assert dataset.to_dict() == expected_output
+
+    def test_with_iterable_dataset(self):
+        examples = {
+            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
+            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
+        }
+        dataset = Dataset.from_dict(examples).to_iterable_dataset()
+        seq_length = 3
+        expected_output = {
+            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
+            "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
+        }
+        dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
+        num_examples = len(examples[next(iter(examples))])
+        assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
+
+
+class TestPackDatasetBfd(TrlTestCase):
+    def test_simple(self):
+        examples = {
+            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
+        }
+        dataset = Dataset.from_dict(examples)
+        seq_length = 4
+        expected_output = {
+            "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
+            "seq_lengths": [[4], [3, 1]],
+        }
+        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
+        assert dataset.to_dict() == expected_output
+
+    def test_with_iterable_dataset(self):
+        examples = {
+            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
+        }
+        dataset = Dataset.from_dict(examples).to_iterable_dataset()
+        seq_length = 4
+        expected_output = {
+            "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
+            "seq_lengths": [[4], [3, 1]],
+        }
+        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
+        num_examples = len(examples[next(iter(examples))])
+        assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
+
+    def test_with_overlong_0(self):
+        examples = {
+            "input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
+        }
+        dataset = Dataset.from_dict(examples)
+        seq_length = 4
+        expected_output = {
+            "input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 5, 12]],
+            "seq_lengths": [[4], [4], [2, 1, 1]],
+        }
+        dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
+        assert dataset.to_dict() == expected_output
+
+    def test_with_overlong_two_columns(self):
+        examples = {
+            "col1": [[1, -2, 3, -4, 5, -6], [7, -8, 9], [-10, 11, -12], [13, -14, 15, -16]],
+            "col2": [[-1, 2, -3, 4, -5, 6], [-7, 8, -9], [10, -11, 12], [-13, 14, -15, 16]],
+        }
+        dataset = Dataset.from_dict(examples)
+        seq_length = 4
+        expected_output = {
+            "col1": [[1, -2, 3, -4], [13, -14, 15, -16], [7, -8, 9], [-10, 11, -12], [5, -6]],
+            "col2": [[-1, 2, -3, 4], [-13, 14, -15, 16], [-7, 8, -9], [10, -11, 12], [-5, 6]],
+            "seq_lengths": [[4], [4], [3], [3], [2]],
+        }
+        dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
+        assert dataset.to_dict() == expected_output
+
+    def test_with_non_power_of_2(self):
+        examples = {
+            "input_ids": [[1, 2, 3, 4, 5], [6], [7, 8, 9, 10], [11, 12, 13]],
+        }
+        dataset = Dataset.from_dict(examples)
+        seq_length = 5
+        expected_output = {
+            "input_ids": [[1, 2, 3, 4, 5], [7, 8, 9, 10, 6], [11, 12, 13]],
+            "seq_lengths": [[5], [4, 1], [3]],
+        }
+        dataset = pack_dataset(dataset, seq_length, strategy="bfd-requeue")
+        assert dataset.to_dict() == expected_output
+
+    def test_default_no_requeue(self):
+        """Test default 'bfd' strategy for SFT datasets (truncates overflow)."""
+        examples = {
+            "input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
+        }
+        dataset = Dataset.from_dict(examples)
+        seq_length = 4
+        # With default 'bfd' strategy, overflow tokens are discarded
+        expected_output = {
+            "input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 12]],
+            "seq_lengths": [[4], [4], [2, 1]],
+        }
+        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
+        assert dataset.to_dict() == expected_output
+
+
+class TestTruncateExamples(TrlTestCase):
+    def test_with_dataset(self):
+        examples = {
+            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
+            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
+        }
+        dataset = Dataset.from_dict(examples)
+        max_length = 2
+        expected_output = {
+            "input_ids": [[1, 2], [4, 5], [8]],
+            "attention_mask": [[0, 1], [0, 0], [1]],
+        }
+        dataset = truncate_dataset(dataset, max_length)
+        assert dataset.to_dict() == expected_output
+
+    def test_with_iterable_dataset(self):
+        examples = {
+            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
+            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
+        }
+        dataset = Dataset.from_dict(examples).to_iterable_dataset()
+        max_length = 2
+        expected_output = {
+            "input_ids": [[1, 2], [4, 5], [8]],
+            "attention_mask": [[0, 1], [0, 0], [1]],
+        }
+        dataset = truncate_dataset(dataset, max_length)
+        num_examples = len(examples[next(iter(examples))])
+        assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output
+
+    def test_with_extra_column(self):
+        examples = {
+            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
+            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
+            "my_column": ["a", "b", "c"],
+        }
+        dataset = Dataset.from_dict(examples)
+        max_length = 2
+        expected_output = {
+            "input_ids": [[1, 2], [4, 5], [8]],
+            "attention_mask": [[0, 1], [0, 0], [1]],
+            "my_column": ["a", "b", "c"],
+        }
+        dataset = truncate_dataset(dataset, max_length)
+        assert dataset.to_dict() == expected_output
+
+
+class TestMaybeConvertToChatML(TrlTestCase):
+    def test_with_conversations_key(self):
+        # Particular case where the key is "conversations": we rename it to "messages"
+        example = {
+            "conversations": [
+                {"from": "user", "value": "What color is the sky?"},
+                {"from": "assistant", "value": "It is blue."},
+            ]
+        }
+        expected_output = {
+            "messages": [
+                {"role": "user", "content": "What color is the sky?"},
+                {"role": "assistant", "content": "It is blue."},
+            ]
+        }
+        assert maybe_convert_to_chatml(example) == expected_output
+
+    def test_without_conversations_key(self):
+        # Same as before, but we don't rename the keys
+        example = {
+            "prompt": [{"from": "user", "value": "What color is the sky?"}],
+            "completion": [{"from": "assistant", "value": "It is blue."}],
+        }
+        expected_output = {
+            "prompt": [{"role": "user", "content": "What color is the sky?"}],
+            "completion": [{"role": "assistant", "content": "It is blue."}],
+        }
+        assert maybe_convert_to_chatml(example) == expected_output
+
+    def test_not_conversational(self):
+        # When not needed, the example should remain unchanged
+        example = {"text": "The sky is blue."}
+        assert maybe_convert_to_chatml(example) == example
+
+    def test_already_chatml(self):
+        # When the example is already in ChatML format, it should remain unchanged
+        example = {
+            "messages": [
+                {"role": "user", "content": "What color is the sky?"},
+                {"role": "assistant", "content": "It is blue."},
+            ]
+        }
+        assert maybe_convert_to_chatml(example) == example
diff --git a/ICL/RL/trl_source/tests/test_grpo_trainer.py b/ICL/RL/trl_source/tests/test_grpo_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cc5f49814628ac7d14bf55bea6abf2d06d39b1b
--- /dev/null
+++ b/ICL/RL/trl_source/tests/test_grpo_trainer.py
@@ -0,0 +1,2803 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+import warnings
+from collections.abc import Callable
+from unittest.mock import patch
+
+import numpy as np
+import pytest
+import torch
+import transformers
+from accelerate.utils.memory import release_memory
+from datasets import Dataset, Features, Image, Value, load_dataset
+from packaging.version import Version
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForImageTextToText,
+    AutoModelForSequenceClassification,
+    AutoProcessor,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+)
+from transformers.testing_utils import backend_empty_cache, torch_device
+from transformers.utils import is_peft_available
+
+from trl import GRPOConfig, GRPOTrainer
+from trl.import_utils import is_liger_kernel_available
+from trl.trainer.utils import get_kbit_device_map
+
+from .testing_utils import (
+    TrlTestCase,
+    require_ampere_or_newer,
+    require_bitsandbytes,
+    require_jmespath,
+    require_kernels,
+    require_liger_kernel,
+    require_peft,
+    require_torch_accelerator,
+    require_vision,
+    require_vllm,
+)
+
+
+if is_peft_available():
+    from peft import LoraConfig, PeftModel, get_peft_model
+
+
+def multiply_tool(a: int, b: int) -> int:
+    """
+    Multiplies two integers.
+
+    Args:
+        a: The first integer.
+        b: The second integer.
+
+    Returns:
+        The product of the two integers.
+    """
+    return a * b
+
+
+async def async_multiply_tool(a: int, b: int) -> int:
+    """
+    Asynchronously multiplies two integers.
+
+    Args:
+        a: The first integer.
+        b: The second integer.
+
+    Returns:
+        The product of the two integers.
+    """
+    return a * b
+
+
+class TestGetHighEntropyMask(TrlTestCase):
+    def get_high_entropy_mask(self, entropies, mask, threshold):
+        """Helper method to test the get_high_entropy_mask functionality."""
+        # Create a mock trainer with minimal setup
+        from unittest.mock import Mock
+
+        # Create a mock accelerator
+        mock_accelerator = Mock()
+        mock_accelerator.num_processes = 1  # Single process for testing
+
+        # Create a minimal trainer instance just to access the method
+        trainer = Mock(spec=GRPOTrainer)
+        trainer.accelerator = mock_accelerator
+        trainer.accelerator.gather = lambda x: x
+        trainer.accelerator.pad_across_processes = lambda x, dim, pad_index: x
+
+        # Call the actual method from GRPOTrainer
+        return GRPOTrainer.get_high_entropy_mask(trainer, entropies, mask, threshold)
+
+    def test_compute_entropy_mask_0(self):
+        # We have a total of 12 tokens, of which 10 are non-pad.
+        # For a top_entropy_quantile of 0.8, we expect the top 20%, i.e. the 2 non-pad tokens with the
+        # highest entropies, to be unmasked.
+        # In our example these are the tokens with entropies 0.9 and 1.0; since 1.1 and 1.2 are pad
+        # tokens, they are excluded from the entropy threshold calculation.
+ entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]]) + mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]]) + entropy_mask = self.get_high_entropy_mask(entropies, mask, threshold=0.8) + expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0]], dtype=torch.bool) + torch.testing.assert_close(entropy_mask, expected_mask) + + def test_compute_entropy_mask_1(self): + # Another example with a different set of entropies and a different mask. + entropies = torch.tensor([[0.1, 0.2, 0.3, 1.4, 0.5, 0.14], [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]]) + mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 0, 0]]) + entropy_mask = self.get_high_entropy_mask(entropies, mask, threshold=0.8) + expected_mask = torch.tensor([[0, 0, 0, 1, 0, 0], [0, 0, 0, 1, 0, 0]], dtype=torch.bool) + torch.testing.assert_close(entropy_mask, expected_mask) + + def test_compute_entropy_mask_lower_threshold(self): + # For a threshold of 0.5 we expect the top half of the non-pad tokens to be unmasked. + entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]]) + mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]]) + entropy_mask = self.get_high_entropy_mask(entropies, mask, threshold=0.5) + expected_mask = torch.tensor([[0, 0, 0, 0, 0, 1], [1, 1, 1, 1, 0, 0]], dtype=torch.bool) + torch.testing.assert_close(entropy_mask, expected_mask) + + def test_compute_entropy_threshold_0(self): + # If the threshold is 0.0 then we expect the mask to be all ones for non-pad tokens. + entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]]) + mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]]) + entropy_mask = self.get_high_entropy_mask(entropies, mask, threshold=0.0) + expected_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]], dtype=torch.bool) + torch.testing.assert_close(entropy_mask, expected_mask) + + def test_compute_entropy_threshold_1(self): + # If the threshold is 1.0 then we expect the mask to be all zeros BUT ONE VALUE. + entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]]) + mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]]) + entropy_mask = self.get_high_entropy_mask(entropies, mask, threshold=1.0) + expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0]], dtype=torch.bool) + torch.testing.assert_close(entropy_mask, expected_mask) + + def test_compute_entropy_all_masked(self): + # If there are no non-pad tokens we expect the mask to be all zeros. 
+        entropies = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
+        mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]])
+        entropy_mask = self.get_high_entropy_mask(entropies, mask, threshold=0.5)
+        expected_mask = torch.tensor([[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]], dtype=torch.bool)
+        torch.testing.assert_close(entropy_mask, expected_mask)
+
+
+class TestGRPOTrainer(TrlTestCase):
+    def test_init_minimal(self):
+        # Test that GRPOTrainer can be instantiated with only model, reward_funcs and train_dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+        GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            train_dataset=dataset,
+        )
+
+    @pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"])
+    def test_training(self, config_name):
+        dataset = load_dataset("trl-internal-testing/zen", config_name, split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
+    @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo"])
+    def test_training_loss_types(self, loss_type):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            importance_sampling_level="sequence" if loss_type == "luspo" else "token",
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=32,  # reduce the completion length to reduce memory usage
+            gradient_accumulation_steps=2,  # set to 2 to test that DAPO can operate with an accumulated batch
+            loss_type=loss_type,
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
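+
+    # The loss types above differ mainly in how the per-token loss is reduced to a scalar. A sketch of the
+    # default "grpo" reduction plus two of the variants, following their descriptions in the GRPOConfig docs
+    # (illustrative; `mask` is the 0/1 completion mask, and the trainer's exact normalization may differ):
+    @staticmethod
+    def _reduce_loss_sketch(per_token_loss, mask, loss_type, max_completion_length=8):
+        if loss_type == "grpo":  # mean per sequence, then mean over the batch
+            return ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1)).mean()
+        if loss_type == "bnpo":  # single mean over all unmasked tokens in the batch
+            return (per_token_loss * mask).sum() / mask.sum().clamp(min=1)
+        if loss_type == "dr_grpo":  # normalize by the constant batch_size * max_completion_length
+            return (per_token_loss * mask).sum() / (mask.size(0) * max_completion_length)
+        raise ValueError(f"unknown loss_type: {loss_type}")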
+ + def test_training_with_eval(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + per_device_eval_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + eval_strategy="steps", + eval_steps=2, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + trainer.train() + + def test_training_with_num_generations_eval(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + per_device_eval_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + num_generations_eval=1, + eval_strategy="steps", + eval_steps=2, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + trainer.train() + + # Regression test for eval_on_start with loss_type="grpo" (one of the loss types that depends on + # current_gradient_accumulation_steps): evaluation runs before the first training step, when that value is still + # unset. Previously this caused the initial eval to crash. 
+ def test_training_eval_on_start(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + per_device_eval_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + loss_type="grpo", + eval_strategy="steps", + eval_steps=2, + eval_on_start=True, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + trainer.train() + + def test_training_multiple_iterations(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + num_iterations=2, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
+
+    @require_peft
+    def test_training_peft_config(self):
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32")
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model=model,
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(),
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model params to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed.")
+            elif "base_layer" not in n:  # We expect the peft params to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
+
+    @require_peft
+    def test_training_peft_model(self):
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32")
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+        lora_config = LoraConfig()
+        model = get_peft_model(model, lora_config)
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model=model,
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model params to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed.")
+            elif "base_layer" not in n and "ref" not in n:  # and the peft params to be different (except base and ref)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
+
+    # In practice, this test is the same as `test_training_peft_config`, since gradient checkpointing is enabled by
+    # default in `GRPOTrainer`. We keep it as a regression guard: if the default ever changes, we still explicitly test
+    # PEFT + gradient checkpointing, which has caused issues in the past.
+    @require_peft
+    def test_training_peft_with_gradient_checkpointing(self):
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32")
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            gradient_checkpointing=True,  # enable gradient checkpointing
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model=model,
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(),
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model params to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed.")
+            elif "base_layer" not in n:  # We expect the peft params to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
+
+    def test_training_different_reward_model(self):
+        # Use a reward model different from the model: different chat template, tokenization, etc.
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
+        reward_model_id = "trl-internal-testing/tiny-LlamaForSequenceClassification-3.2"
+        reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id)
+        reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_id)
+        # By default, the trainer uses the eos token as the padding token. However, for Llama models, the eos token
+        # appears in the chat template. Using it as a pad token disrupts the reward calculation, as the calculation
+        # considers the score of the last token before the first pad token. To ensure correct reward calculations,
+        # we use a separate pad token instead.
+ reward_tokenizer.pad_token = "<|finetune_right_pad_id|>" + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_model, + args=training_args, + train_dataset=dataset, + reward_processing_classes=reward_tokenizer, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + def test_training_reward_func_standard(self): + # Test if trainer can handle reward function with standard format + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
+ + def test_training_reward_func_conversational(self): + # Test if trainer can handle reward function with conversational format + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that gives higher scores to longer completion content.""" + completion_contents = [completion[0]["content"] for completion in completions] + return [float(len(content)) for content in completion_contents] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + def test_training_multiple_reward_funcs(self): + # Test that GRPOTrainer can be instantiated with multiple reward functions + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func1(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + def reward_func2(completions, **kwargs): + """Reward function that rewards completions with more unique letters.""" + return [float(len(set(completion))) for completion in completions] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[reward_func1, reward_func2], + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
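+
+    # The standard and conversational tests above show the two shapes the `completions` argument can take:
+    # plain strings, or lists of {"role", "content"} message dicts. A format-agnostic length reward, as a
+    # sketch (illustrative; the reward functions in the tests deliberately stay format-specific):
+    @staticmethod
+    def _length_reward_sketch(completions, **kwargs):
+        texts = [c if isinstance(c, str) else c[0]["content"] for c in completions]
+        return [float(len(t)) for t in texts]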
+
+    def test_training_sync_and_async_reward_funcs(self):
+        # Test that GRPOTrainer can be instantiated with multiple reward functions, one of which is async
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        def sync_reward_func1(completions, **kwargs):
+            """Reward function that rewards longer completions."""
+            return [float(len(completion)) for completion in completions]
+
+        def sync_reward_func2(completions, **kwargs):
+            return [1 for _ in completions]
+
+        async def async_reward_func(completions, **kwargs):
+            """Async reward function that rewards completions with more unique letters."""
+            return [float(len(set(completion))) for completion in completions]
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs=[sync_reward_func1, sync_reward_func2, async_reward_func],
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
+    def test_training_multiple_reward_funcs_with_None_output(self):
+        """Test that an applicable reward function is processed correctly while another reward function returns None."""
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        def applicable_reward_func(completions, **kwargs):
+            """A reward function that rewards longer completions."""
+            return [float(len(completion)) for completion in completions]
+
+        def non_applicable_reward_func(completions, **kwargs):
+            """A reward function that returns None for all inputs, as it is not applicable to this sample."""
+            return [None] * len(completions)
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,
+            num_generations=3,
+            max_completion_length=8,
+            report_to="none",
+        )
+
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs=[
+                applicable_reward_func,
+                non_applicable_reward_func,
+            ],  # one applicable, one non-applicable
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {
+            n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad
+        }
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+ + def test_training_multiple_reward_funcs_with_weights(self): + """Test that GRPOTrainer can handle multiple reward functions with weights.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func1(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + def reward_func2(completions, **kwargs): + """Reward function that rewards completions with more unique letters.""" + return [float(len(set(completion))) for completion in completions] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + reward_weights=[0.7, 0.3], # weight of reward_func1 and reward_func2 respectively + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[reward_func1, reward_func2], + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + # Check that training logs contain both reward metrics + assert trainer.state.log_history[-1]["train_loss"] is not None + assert "rewards/reward_func1/mean" in trainer.state.log_history[-1] + assert "rewards/reward_func1/std" in trainer.state.log_history[-1] + assert "rewards/reward_func2/mean" in trainer.state.log_history[-1] + assert "rewards/reward_func2/std" in trainer.state.log_history[-1] + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + def test_training_multiple_mixed_reward_funcs(self): + # Test if the trainer can handle a mix of reward functions and reward models + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[reward_func, "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"], + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
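+
+    # How the per-function rewards in the weighted and None-output tests above are aggregated, sketched
+    # minimally (illustrative; assumes rewards are stacked as a (num_samples, num_funcs) tensor with None
+    # outputs encoded as NaN):
+    @staticmethod
+    def _combine_rewards_sketch(rewards_per_func, weights):
+        weights = torch.tensor(weights, dtype=rewards_per_func.dtype)
+        # nansum lets a non-applicable function (NaN reward) contribute nothing to the total
+        return torch.nansum(rewards_per_func * weights, dim=1)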
+
+    def test_training_reward_func_additional_column(self):
+        # Test if the trainer can handle reward functions that rely on additional columns in the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        # Add a column to the dataset (dummy example, the column could be anything)
+        some_values = list(range(len(dataset)))
+        dataset = dataset.add_column("some_values", some_values)
+
+        def reward_func(completions, some_values, **kwargs):
+            """Reward function that scores each completion by the absolute difference between its length and the
+            corresponding value in some_values."""
+            return [
+                float(abs(len(completion) - value)) for completion, value in zip(completions, some_values, strict=True)
+            ]
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs=reward_func,
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
+    def test_training_with_sync_ref_model(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            beta=0.1,  # ensure ref model is created so sync can update it
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            sync_ref_model=True,
+            ref_model_sync_steps=2,  # reduce sync steps to ensure a sync happens
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+        assert trainer.ref_model is not None
+        previous_ref_params = {n: param.clone() for n, param in trainer.ref_model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+            new_ref_param = trainer.ref_model.get_parameter(n)
+            assert not torch.equal(previous_ref_params[n], new_ref_param), f"Ref Parameter {n} has not changed."
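+
+    # The sync exercised by test_training_with_sync_ref_model above, sketched (illustrative; assuming the
+    # mixup update ref <- alpha * policy + (1 - alpha) * ref is applied every ref_model_sync_steps steps):
+    @staticmethod
+    def _sync_ref_model_sketch(ref_model, model, alpha=0.6):
+        with torch.no_grad():
+            for ref_param, param in zip(ref_model.parameters(), model.parameters()):
+                ref_param.mul_(1.0 - alpha).add_(param, alpha=alpha)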
+ + def test_training_beta_non_zero(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + training_args = GRPOConfig( + output_dir=self.tmp_dir, + beta=0.1, # set beta to non-zero value to test the case where the reference model is used + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + def test_get_off_policy_mask(self): + """ + Test the logic of off-policy masking: + - Keep if Advantage >= 0 + - Keep if KL <= threshold + - Drop if Advantage < 0 AND KL > threshold + """ + mask = torch.ones((3, 4)) # B=3 sequences, T=4 tokens + + advantages = torch.tensor([1.0, -1.0, -1.0]).unsqueeze(-1) + sampling_per_token_logps = torch.zeros((3, 4)) + per_token_logps = torch.zeros((3, 4)) + + per_token_logps[0, :] = -2.0 # Pos adv + High KL (0−(−2)=2) -> Keep + per_token_logps[1, :] = -0.5 # Neg adv + Low KL (0.5) -> Keep + per_token_logps[2, :] = -2.0 # Neg adv + High KL (2.0) -> Drop + + off_policy_threshold = 1.0 + + expected_mask = torch.tensor([[1.0], [1.0], [0.0]]) + + off_policy_mask = GRPOTrainer.get_off_policy_mask( + advantages, per_token_logps, sampling_per_token_logps, mask, off_policy_threshold + ) + + torch.testing.assert_close(off_policy_mask, expected_mask) + + def test_get_off_policy_mask_padding(self): + """Test that padding is correctly ignored in KL calculation.""" + mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]]) # 2 valid tokens + advantages = torch.tensor([[-1.0]]) # Negative advantage + + sampling_per_token_logps = torch.zeros((1, 4)) + per_token_logps = torch.zeros((1, 4)) + + # Valid tokens have High KL (2.0) + per_token_logps[0, 0] = -2.0 + per_token_logps[0, 1] = -2.0 + + # Padding tokens have abnormal values (should be ignored) + per_token_logps[0, 2] = -10_000.0 + per_token_logps[0, 3] = 10_000.0 + + off_policy_threshold = 1.0 + + # Avg KL on valid tokens = (2+2)/2 = 2.0 > 1.0 -> Drop + expected_mask = torch.tensor([[0.0]]) + + off_policy_mask = GRPOTrainer.get_off_policy_mask( + advantages, per_token_logps, sampling_per_token_logps, mask, off_policy_threshold + ) + + torch.testing.assert_close(off_policy_mask, expected_mask) + + # Now test with Low KL on valid tokens + per_token_logps[0, 0] = -0.5 + per_token_logps[0, 1] = -0.5 + # Avg KL = 0.5 <= 1.0 -> Keep + expected_mask_keep = torch.tensor([[1.0]]) + + off_policy_mask_keep = GRPOTrainer.get_off_policy_mask( + advantages, per_token_logps, sampling_per_token_logps, mask, off_policy_threshold + ) + + torch.testing.assert_close(off_policy_mask_keep, expected_mask_keep) + + def test_training_with_off_policy_mask(self): + dataset = 
load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + off_policy_mask_threshold=0.5, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + @require_liger_kernel + @pytest.mark.xfail(reason="Off-Policy Masking isn't compatible with Liger yet.") + def test_training_with_off_policy_mask_with_liger(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + off_policy_mask_threshold=0.5, + use_liger_kernel=True, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
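+
+    # A compact restatement of the rule pinned down by test_get_off_policy_mask above (illustrative, not the
+    # trainer's implementation): a sequence is dropped only when its advantage is negative AND its mean
+    # per-token log-ratio to the sampling policy exceeds the threshold.
+    @staticmethod
+    def _off_policy_mask_sketch(advantages, per_token_logps, sampling_per_token_logps, mask, threshold):
+        per_token_kl = (sampling_per_token_logps - per_token_logps) * mask  # pad tokens contribute 0
+        mean_kl = per_token_kl.sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
+        drop = (advantages.squeeze(-1) < 0) & (mean_kl > threshold)
+        return (~drop).float().unsqueeze(-1)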
+ + def test_training_with_bias_correction_kl(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + training_args = GRPOConfig( + output_dir=self.tmp_dir, + beta=0.1, # set beta to non-zero value to test the case where the reference model is used + use_bias_correction_kl=True, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + @pytest.mark.parametrize( + "model_name", + ["trl-internal-testing/tiny-Qwen3ForCausalLM", "trl-internal-testing/tiny-Gemma2ForCausalLM"], + # Gemma2 has the input word embeddings and lm_head tied, Qwen3 does not + ) + def test_training_with_cast_lm_head_to_fp32(self, model_name): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=8, + report_to="none", + cast_lm_head_to_fp32=True, + ) + trainer = GRPOTrainer( + model=model_name, + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.model.lm_head.weight.dtype == torch.float32 + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
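+
+    # The parametrization above is deliberate: Gemma2 ties lm_head to the input embeddings, so upcasting the
+    # head also upcasts the embedding weight, while Qwen3 keeps them separate. The cast itself amounts to
+    # roughly this sketch (illustrative, not TRL's implementation):
+    @staticmethod
+    def _cast_lm_head_sketch(model):
+        model.lm_head = model.lm_head.float()  # in-place module cast to float32
+        return model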
+
+    def test_training_with_entropy_filter(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            report_to="none",
+            top_entropy_quantile=0.2,
+        )
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
+    @require_peft
+    @require_vllm
+    @pytest.mark.skip(reason="We should add a mock for the vLLM server.")
+    def test_training_vllm_and_peft(self):
+        """Test that training works with vLLM for generation."""
+        model = AutoModelForCausalLM.from_pretrained(
+            "Qwen/Qwen2.5-0.5B-Instruct", dtype="float32"
+        )  # tiny model is too small for vLLM
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            report_to="none",
+            use_vllm=True,
+        )
+        lora_config = LoraConfig(
+            target_modules="all-linear",
+            # test with non-default modules as it adds extra keys in state_dict that we need to handle
+            modules_to_save=["embed_tokens", "lm_head"],
+        )
+        trainer = GRPOTrainer(
+            model=model,
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=lora_config,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model params to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed.")
+            elif "base_layer" not in n and "original_module" not in n:
+                # We expect the peft params to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
+
+    @require_vllm
+    @pytest.mark.skip(reason="We should add a mock for the vLLM server.")
+    def test_training_vllm_structured_outputs(self):
+        """Test that training works with vLLM for generation with structured outputs."""
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            report_to="none",
+            use_vllm=True,
+            vllm_structured_outputs_regex=r"<think>\n.*\n</think>\n\n<answer>\n.*\n</answer>",
+        )
+        trainer = GRPOTrainer(
+            model="Qwen/Qwen2.5-0.5B-Instruct",  # tiny model is too small for vLLM
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
+    @require_vllm
+    @pytest.mark.skip(reason="We should add a mock for the vLLM server.")
+    def test_training_vllm_importance_sampling_correction(self):
+        """Test that training works with vLLM for generation with importance sampling correction."""
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,
+            num_generations=3,
+            max_completion_length=8,
+            report_to="none",
+            use_vllm=True,
+            vllm_importance_sampling_correction=True,
+            vllm_importance_sampling_cap=3.0,
+        )
+        trainer = GRPOTrainer(
+            model="Qwen/Qwen2.5-0.5B-Instruct",  # tiny model is too small for vLLM
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
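+
+    # For reference (an illustrative sketch, not TRL's exact implementation): the
+    # correction exercised above reweights tokens by a capped ratio between the
+    # trainer's and vLLM's log-probabilities, roughly
+    #     weight = (train_logps - vllm_logps).exp().clamp(max=vllm_importance_sampling_cap)
+    # so that numerical drift between the two inference paths cannot blow up the loss.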
+ + def test_training_with_additional_generation_kwargs(self): + """Test that training works with additional generation kwargs.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + top_p=0.9, + top_k=10, + min_p=0.01, + repetition_penalty=1.1, + ) + + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + @require_vllm + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") + def test_training_vllm_with_additional_generation_kwargs(self): + """Test that training works with vLLM and additional generation kwargs.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + use_vllm=True, + top_p=0.9, + top_k=10, + min_p=0.01, + repetition_penalty=1.1, + ) + + trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
+ + def test_training_normalize_then_sum_aggregation(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func1(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + def reward_func2(completions, **kwargs): + """Reward function that rewards completions with more unique letters.""" + return [float(len(set(completion))) for completion in completions] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + multi_objective_aggregation="normalize_then_sum", + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[reward_func1, reward_func2], + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + @pytest.mark.parametrize("scale_rewards", [False, "group", "batch", True, "none"]) + def test_training_scale_rewards(self, scale_rewards): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + scale_rewards=scale_rewards, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + @patch("transformers.generation.utils.GenerationMixin.generate") + def test_training_with_mask_truncated_completions(self, mock_generate): + """Test that training works with mask_truncated_completions=True parameter.""" + + # We mock the generate method because the model's random weights make it extremely unlikely to produce a + # sequence containing the EOS token within the allowed max_completion_length. As a result, all tokens are + # masked in the loss, the model doesn't update, and the final check (which verifies the update) fails. 
+        def fake_generate(input_ids, **kwargs):
+            # pad_token_id = 151643; eos_token_id = 151645
+            completion_ids = torch.tensor(
+                [
+                    [1, 2, 3, 4, 5, 6, 7, 8],  # this one is truncated
+                    [9, 10, 11, 151645, 151643, 151643, 151643, 151643],  # this one contains eos
+                    [12, 13, 14, 15, 16, 17, 18, 151645],  # particular case, eos is generated just within the limit
+                ],
+                device=input_ids.device,
+            )
+            return torch.cat([input_ids, completion_ids], dim=1)
+
+        mock_generate.side_effect = fake_generate
+
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            mask_truncated_completions=True,  # Enable masking of truncated completions
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
+    def test_training_with_mask_truncated_completions_all_masked(self):
+        """
+        Test that when all generated completions are truncated (i.e., none contain an EOS token), and
+        mask_truncated_completions=True, the model receives no effective learning signal and therefore does not update
+        its parameters.
+
+        Here, we don't mock the generate method, but rely on the fact that the probability of the model generating
+        the EOS token is extremely low, so all generated completions are truncated.
+        """
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            mask_truncated_completions=True,  # Enable masking of truncated completions
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert torch.equal(param, new_param), f"Parameter {n} has changed."
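+
+    # Sketch of the mechanism the test above relies on (illustrative only, not
+    # TRL's exact code): with every completion token masked out, a mask-normalized
+    # loss such as
+    #     loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+    # evaluates to zero, so backpropagation yields no gradient and no update.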
+ + def test_warning_raised_all_rewards_none(self, caplog): + """Test that a proper warning is raised when all rewards are None.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def always_none_reward_func(completions, **kwargs): + """Reward function that always returns None.""" + return [None] * len(completions) + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=always_none_reward_func, + args=training_args, + train_dataset=dataset, + ) + + with caplog.at_level("WARNING", logger="trl.trainer.grpo_trainer"): + trainer.train() + + expected_warning = "All reward functions returned None for the following kwargs:" + assert expected_warning in caplog.text + + def test_training_num_generations_larger_than_batch_size(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + num_generations=6, # the number of generations is larger than the batch size, but + gradient_accumulation_steps=2, # gradient accumulation should allow that + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
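+
+    # Why the configuration above is valid (illustrative arithmetic): the effective
+    # generation batch is per_device_train_batch_size * gradient_accumulation_steps
+    # = 3 * 2 = 6, which is divisible by num_generations=6, so the 6 completions of
+    # a single prompt can span two accumulated micro-batches.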
+ + def test_training_delta_clipping(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + delta=2.0, # set delta to a non-None value + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + def test_training_multiple_dataloader_workers(self): + # Pytest/CI often starts background threads before tests run. With Python 3.12, using the default "fork" start + # method in a multi-threaded process emits a DeprecationWarning and may deadlock. + # + # We force "spawn" here to make multiprocessing safe under pytest when DataLoader workers are enabled. This is + # test-environment–specific and not required by the training logic itself. + # + # This means the test does not cover "fork". However, "spawn" is stricter (requires full picklability and clean + # state) and avoids fork-after-threads issues that pytest cannot reliably test anyway. Fork-specific behavior, + # if needed, should be tested in a clean process outside pytest. + torch.multiprocessing.set_start_method("spawn", force=True) + + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + dataloader_num_workers=2, # use multiple dataloader workers + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
+ + def test_training_with_generation_kwargs(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + # Pass gen kwargs + generation_kwargs={"do_sample": True, "top_k": 50, "num_beams": 2, "length_penalty": -0.1}, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + def test_training_with_reward_func_accessing_trainer_state(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func(completions, **kwargs): + trainer_state = kwargs.get("trainer_state") + assert trainer_state is not None + # transformers.TrainerState instance should have a `global_step` property. + assert hasattr(trainer_state, "global_step") + return [float(len(set(completion))) for completion in completions] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + num_generations=2, + max_completion_length=8, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + trainer.train() + + def test_prepare_input_called_with_correct_data(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + max_completion_length=8, # reduce the completion length to reduce memory usage + gradient_accumulation_steps=3, # can be anything in this test + # steps_per_generation*per_device_train_batch_size=24 is divisible by num_generations=4 + steps_per_generation=4, + num_generations=4, + per_device_train_batch_size=6, # reduce the batch size to reduce memory usage + num_iterations=2, + shuffle_dataset=False, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + # steps_per_generation=4, per_device_train_batch_size=6 and num_generations=4, so we expect a + # generation batch of 24 samples (steps_per_generation * per_device_train_batch_size), containing 6 + # different prompts (steps_per_generation * per_device_train_batch_size // num_generations), each repeated + # 4 times (num_generations). 
+ expected_first_generation_batch = ( + [{"prompt": "Beautiful is better than"}] * 4 + + [{"prompt": "Explicit is"}] * 4 + + [{"prompt": "Simple is better"}] * 4 + + [{"prompt": "Complex"}] * 4 + + [{"prompt": "Flat is better than"}] * 4 + + [{"prompt": "Sparse is better"}] * 4 + ) + expected_second_generation_batch = ( + [{"prompt": "Readability"}] * 4 + + [{"prompt": "Special cases aren't special"}] * 4 + + [{"prompt": "Although practicality beats"}] * 4 + + [{"prompt": "Errors should never"}] * 4 + + [{"prompt": "Unless explicitly"}] * 4 + + [{"prompt": "In the face of ambiguity, refuse"}] * 4 + ) + + with patch.object(GRPOTrainer, "training_step", wraps=trainer.training_step) as mock_prepare: + trainer.train() + # 3 epochs * 2 iterations * 2 generation batches to cover the dataset * 4 steps_per_generation + assert mock_prepare.call_count == 48 + for i in range(0, 8): # Generation batch repeated 8 times (steps_per_generation*num_iterations) + assert mock_prepare.call_args_list[i].args[1] == expected_first_generation_batch + for i in range(8, 16): + assert mock_prepare.call_args_list[i].args[1] == expected_second_generation_batch + + @pytest.mark.parametrize( + "model_id", + [ + "trl-internal-testing/tiny-Gemma3ForConditionalGeneration", + "trl-internal-testing/tiny-LlavaNextForConditionalGeneration", + "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + "trl-internal-testing/tiny-Qwen2VLForConditionalGeneration", + # "trl-internal-testing/tiny-SmolVLMForConditionalGeneration", seems not to support bf16 properly + ], + ) + @require_vision + def test_training_vlm(self, model_id): + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model=model_id, + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + # Because of the way the tiny models are initialized, the gradient does not flow properly through the + # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. + params_to_skip = ( + "model.vision_tower.", + "model.multi_modal_projector.", + "model.vision_model.", + "model.visual.", + "model.image_newline", + ) + for n, param in previous_trainable_params.items(): + if n.startswith(params_to_skip): + continue + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
+ + @pytest.mark.parametrize( + "model_id", + [ + "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + ], + ) + @require_vision + def test_training_vlm_beta_non_zero(self, model_id): + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + beta=0.1, # set beta to non-zero value to test the case where the reference model is used + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model=model_id, + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + # Because of the way the tiny models are initialized, the gradient does not flow properly through the + # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. + params_to_skip = ("model.visual.",) + for n, param in previous_trainable_params.items(): + if n.startswith(params_to_skip): + continue + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
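+
+    # Context for the non-zero beta above (illustrative, not TRL's exact code):
+    # beta scales a per-token KL penalty against the frozen reference model,
+    # roughly
+    #     loss = pg_loss + beta * kl_to_ref
+    # which is why this test exercises the reference-model code path.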
+
+    @pytest.mark.parametrize(
+        "model_id",
+        [
+            "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
+        ],
+    )
+    @require_vision
+    @require_peft
+    def test_training_vlm_peft(self, model_id):
+        model = AutoModelForImageTextToText.from_pretrained(model_id, dtype="float32")
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+
+        def reward_func(completions, **kwargs):
+            """Reward function that rewards longer completions."""
+            return [float(len(completion[0]["content"])) for completion in completions]
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model=model,
+            reward_funcs=reward_func,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(target_modules=["q_proj", "v_proj"]),
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model params to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed.")
+            elif "base_layer" not in n:  # We expect the peft params to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
+
+    @pytest.mark.parametrize(
+        "model_id",
+        [
+            "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
+        ],
+    )
+    @require_vision
+    def test_training_vlm_and_importance_sampling(self, model_id):
+        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+
+        def reward_func(completions, **kwargs):
+            """Reward function that rewards longer completions."""
+            return [float(len(completion[0]["content"])) for completion in completions]
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            steps_per_generation=2,  # increase the steps per generation to trigger IS
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model=model_id,
+            reward_funcs=reward_func,
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        # Because of the way the tiny models are initialized, the gradient does not flow properly through the
+        # vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
+ params_to_skip = ("model.visual.",) + for n, param in previous_trainable_params.items(): + if n.startswith(params_to_skip): + continue + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + @pytest.mark.xfail( + Version(transformers.__version__) >= Version("5.0.0") and not is_liger_kernel_available(min_version="0.6.5"), + reason="Blocked by upstream liger-kernel bug (linkedin/Liger-Kernel#960); " + "fixed by linkedin/Liger-Kernel#966 but not yet released (>0.6.4 required)", + strict=True, + ) + @pytest.mark.parametrize( + "model_id", + [ + "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + ], + ) + @require_vision + @require_liger_kernel + def test_training_vlm_and_liger(self, model_id): + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + use_liger_kernel=True, # enable Liger kernel + report_to="none", + ) + trainer = GRPOTrainer( + model=model_id, + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + # Because of the way the tiny models are initialized, the gradient does not flow properly through the + # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. + params_to_skip = ("model.visual.",) + for n, param in previous_trainable_params.items(): + if n.startswith(params_to_skip): + continue + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
+
+    @pytest.mark.parametrize(
+        "model_id",
+        [
+            "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
+            "trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
+        ],
+    )
+    @require_vision
+    @require_vllm
+    @pytest.mark.skip(reason="We should add a mock for the vLLM server.")
+    def test_training_vlm_and_vllm(self, model_id) -> None:
+        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+
+        def reward_func(completions, **kwargs):
+            """Reward function that rewards longer completions."""
+            return [float(len(completion[0]["content"])) for completion in completions]
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,
+            num_generations=3,
+            max_completion_length=8,
+            report_to="none",
+            use_vllm=True,
+            vllm_mode="server",
+        )
+        trainer = GRPOTrainer(
+            model=model_id,
+            reward_funcs=reward_func,
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
+    @pytest.mark.parametrize(
+        "model_id",
+        [
+            "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
+        ],
+    )
+    @require_vision
+    def test_training_vlm_multi_image(self, model_id):
+        dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")
+
+        def reward_func(completions, **kwargs):
+            """Reward function that rewards longer completions."""
+            return [float(len(completion[0]["content"])) for completion in completions]
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model=model_id,
+            reward_funcs=reward_func,
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        # Because of the way the tiny models are initialized, the gradient does not flow properly through the
+        # vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
+        params_to_skip = ("model.visual.",)
+        for n, param in previous_trainable_params.items():
+            if n.startswith(params_to_skip):
+                continue
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+ + def test_training_sequence_importance_sampling(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + num_iterations=2, # the importance sampling weights won't be 0 in this case + importance_sampling_level="sequence", + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + def test_training_with_chat_template_kwargs(self): + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=8, + report_to="none", + chat_template_kwargs={"enable_thinking": False}, + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen3ForCausalLM", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0"), + reason="Tool parsing is not supported in transformers versions below 5.0.0", + strict=True, + ) + @require_jmespath + @pytest.mark.parametrize("tools", [[multiply_tool], [async_multiply_tool]]) + def test_training_with_tools(self, tools: list[Callable]): + # In this test, we define a simple tool that multiplies two integers. Regardless of the input prompt, + # the model will generate 3 completions, 2 of which will be valid tool calls. Among the 2 tool calls, one will + # succeed and the other will fail (because of a wrong argument name). 
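+        # multiply_tool and async_multiply_tool are defined elsewhere in this file
+        # (outside this excerpt); presumably simple schema-carrying callables along
+        # the lines of (hypothetical):
+        #     def multiply_tool(a: int, b: int) -> int:
+        #         """Multiply two integers."""
+        #         return a * b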
+
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,
+            num_generations=3,
+            max_completion_length=128,
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+            tools=tools,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+        tool_name = tools[0].__name__
+
+        def fake_generate(input_ids, **kwargs):
+            if input_ids.shape[0] == 3:  # first call
+                # fmt: off
+                if tool_name == "multiply_tool":
+                    completion_ids = torch.tensor(
+                        [
+                            # '<tool_call>\n{"name": "multiply_tool", "arguments": {"a": 3, "b": 4}}\n</tool_call><|im_end|>'
+                            [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645],
+                            # an invalid tool call with wrong argument name
+                            # '<tool_call>\n{"name": "multiply_tool", "arguments": {"a": 3, "c": 4}}\n</tool_call><|im_end|>'
+                            [151657, 198, 4913, 606, 788, 330, 64648, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 66, 788, 220, 19, 11248, 151658, 151645],
+                            # "I don't know any tool<|im_end|>"
+                            [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643],
+                        ],
+                        device=input_ids.device,
+                    )
+                elif tool_name == "async_multiply_tool":
+                    completion_ids = torch.tensor(
+                        [
+                            # '<tool_call>\n{"name": "async_multiply_tool", "arguments": {"a": 3, "b": 4}}\n</tool_call><|im_end|>'
+                            [151657, 198, 4913, 606, 788, 330, 7692, 93054, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645],
+                            # an invalid tool call with wrong argument name
+                            # '<tool_call>\n{"name": "async_multiply_tool", "arguments": {"a": 3, "c": 4}}\n</tool_call><|im_end|>'
+                            [151657, 198, 4913, 606, 788, 330, 7692, 93054, 22785, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 66, 788, 220, 19, 11248, 151658, 151645],
+                            # "I don't know any tool<|im_end|>"
+                            [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643],
+                        ],
+                        device=input_ids.device,
+                    )
+                # fmt: on
+            else:  # second call will only have two inputs in the batch, because two examples have a tool call.
+                completion_ids = torch.tensor(
+                    [
+                        # 'Done!<|im_end|>'
+                        [17453, 0, 151645],
+                        # 'Done!<|im_end|>'
+                        [17453, 0, 151645],
+                    ],
+                    device=input_ids.device,
+                )
+            return torch.cat([input_ids, completion_ids], dim=-1)
+
+        with patch.object(trainer.model, "generate", side_effect=fake_generate):
+            trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+        assert trainer.state.log_history[-1]["tools/call_frequency"] is not None
+        assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3)
+        assert trainer.state.log_history[-1]["tools/failure_frequency"] is not None
+        assert trainer.state.log_history[-1]["tools/failure_frequency"] == pytest.approx(1 / 2)
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
+    @pytest.mark.xfail(
+        condition=Version(transformers.__version__) < Version("5.0.0"),
+        reason="Tool parsing is not supported in transformers versions below 5.0.0",
+        strict=True,
+    )
+    @require_jmespath
+    def test_training_with_malformed_tool_calls(self):
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            per_device_train_batch_size=3,
+            num_generations=3,
+            max_completion_length=128,
+            report_to="none",
+        )
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+            tools=[multiply_tool],
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        def fake_generate(input_ids, **kwargs):
+            # If input_ids.shape[0] < 3, it means that it's a second call, which should not happen here
+            assert input_ids.shape[0] == 3
+            # fmt: off
+            completion_ids = torch.tensor(
+                [
+                    # '<tool_call>\n{"arguments": {"a": 3, "b": 4}}\n</tool_call><|im_end|>'
+                    [151657, 198, 4913, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645],
+                    # '<tool_call>\n{"name": "multiply_tool"}\n</tool_call><|im_end|>'
+                    [151657, 198, 4913, 606, 788, 330, 64648, 22785, 16707, 151658, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643],
+                    # "<tool_call>\n{}\n</tool_call><|im_end|>"
+                    [151657, 198, 16094, 151658, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643],
+                ],
+                device=input_ids.device,
+            )
+            # fmt: on
+            return torch.cat([input_ids, completion_ids], dim=-1)
+
+        with patch.object(trainer.model, "generate", side_effect=fake_generate):
+            trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+ + def test_mismatched_reward_processing_classes_length(self): + """Test that mismatched length between reward_funcs and reward_processing_classes raises error.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + # Use two reward models + reward_models = [ + "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + "trl-internal-testing/tiny-Qwen3ForSequenceClassification", + ] + + # Create a single processing class (tokenizer) + single_processing_class = AutoTokenizer.from_pretrained( + "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + ) + + training_args = GRPOConfig(output_dir=self.tmp_dir, report_to="none") + + with pytest.raises(ValueError, match="must match"): + GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_models, + reward_processing_classes=single_processing_class, # only one, but need two + args=training_args, + train_dataset=dataset, + ) + + def test_correct_reward_processing_classes_list(self): + """Test that correct list of reward_processing_classes works properly.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + # Use two reward models + reward_models = [ + "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + "trl-internal-testing/tiny-Qwen3ForSequenceClassification", + ] + + # Create processing classes + processing_class1 = AutoTokenizer.from_pretrained( + "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + ) + processing_class2 = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForSequenceClassification") + + training_args = GRPOConfig(output_dir=self.tmp_dir, report_to="none") + + # Correct list length should work + correct_processing_classes = [processing_class1, processing_class2] + + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_models, + reward_processing_classes=correct_processing_classes, + args=training_args, + train_dataset=dataset, + ) + + assert len(trainer.reward_processing_classes) == len(reward_models) + + def test_single_reward_model_with_single_processing_class(self): + """Test that single reward model with single processing class works.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + # Use single reward model + reward_model = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + + # Create a single processing class (tokenizer) + single_processing_class = AutoTokenizer.from_pretrained( + "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + ) + + training_args = GRPOConfig(output_dir=self.tmp_dir, report_to="none") + + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_model, + reward_processing_classes=single_processing_class, # single object for single reward model + args=training_args, + train_dataset=dataset, + ) + + assert len(trainer.reward_processing_classes) == 1 + assert trainer.reward_processing_classes[0] == single_processing_class + + +@pytest.mark.slow +@require_torch_accelerator +class TestGRPOTrainerSlow(TrlTestCase): + def setup_method(self): + self.train_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + self.eval_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="test") + self.max_length = 128 + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + gc.collect() 
+ + @pytest.mark.parametrize( + "model_name", + [ + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", + ], + ) + @require_liger_kernel + def test_training_with_liger_grpo_kernel(self, model_name): + training_args = GRPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=3, + num_generations=3, + use_liger_kernel=True, + max_completion_length=self.max_length, + report_to="none", + logging_strategy="no", + ) + + model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32") + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token + + trainer = GRPOTrainer( + model=model, + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + processing_class=tokenizer, + ) + from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss + + assert isinstance(trainer.liger_grpo_loss, LigerFusedLinearGRPOLoss) + + previous_trainable_params = {n: param.clone() for n, param in model.named_parameters()} + + trainer.train() + + for n, param in previous_trainable_params.items(): + new_param = model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + release_memory(model, trainer) + + @pytest.mark.parametrize( + "model_name", + [ + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", + ], + ) + @require_liger_kernel + @require_peft + def test_training_with_liger_grpo_kernel_and_peft(self, model_name): + from peft import LoraConfig, TaskType + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=3, + num_generations=3, + use_liger_kernel=True, + max_completion_length=self.max_length, + report_to="none", + logging_strategy="no", + ) + + model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32") + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token + + # Configure PEFT with LoRA + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, + target_modules=["q_proj", "v_proj"], + ) + + trainer = GRPOTrainer( + model=model, + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + processing_class=tokenizer, + peft_config=peft_config, + ) + from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss + + assert isinstance(trainer.liger_grpo_loss, LigerFusedLinearGRPOLoss) + + # Verify PEFT adapter is properly initialized + from peft import PeftModel + + assert isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT" + + # Store adapter weights before training + previous_trainable_params = { + n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad + } + assert len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model" + + trainer.train() + + # Verify adapter weights have changed after training + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
+ + release_memory(model, trainer) + + @pytest.mark.parametrize( + "model_name", + [ + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", + ], + ) + def test_training_with_transformers_paged(self, model_name): + """Test that training works with transformers paged implementation (requires GPU).""" + if Version(transformers.__version__) < Version("4.57.0"): + pytest.xfail("Bug in transformers solved in GH#40692, released in 4.57.0.") + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + use_transformers_paged=True, # Enable transformers paged implementation + report_to="none", + logging_strategy="no", + ) + + model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32") + + trainer = GRPOTrainer( + model=model, + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=self.train_dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + release_memory(model, trainer) + + @pytest.mark.parametrize( + "model_name", + [ + "HuggingFaceTB/SmolVLM-Instruct", # Only test the smaller model to avoid OOM + ], + ) + @require_kernels + @require_ampere_or_newer # Flash attention 2 requires Ampere or newer GPUs + @require_bitsandbytes + @require_peft + def test_vlm_training(self, model_name): + """ + Test VLM training with aggressive memory optimization. 
+ + This test uses multiple memory reduction techniques: + - 4-bit quantization with double quantization + - LoRA with very low rank (r=4) + - Minimal batch size (1) with gradient accumulation + - Small images (64x64 instead of 224x224) + - Short sequences (max_completion_length=8) + - Only 4 training samples + - Only 1 training step + - Gradient checkpointing and bfloat16 + """ + + # Create processor once outside the data generator + processor = AutoProcessor.from_pretrained(model_name, use_fast=True, padding_side="left") + conversation = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is in the image?"}, + ], + }, + ] + prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + + def data_gen(num_samples): + for _ in range(num_samples): + yield { + "prompt": prompt, + "image": np.random.uniform(low=0.0, high=255.0, size=(64, 64, 3)).astype( + np.uint8 + ), # Much smaller images + } + + dataset = Dataset.from_generator( + data_gen, gen_kwargs={"num_samples": 4}, features=Features(image=Image(), prompt=Value(dtype="string")) + ) + # reduce memory requirements as much as possible + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype="bfloat16", + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_storage="bfloat16", + ) + model = AutoModelForImageTextToText.from_pretrained( + model_name, + attn_implementation="kernels-community/flash-attn2", + dtype="float32", + device_map=get_kbit_device_map(), + quantization_config=quantization_config, + ) + + def reward_func(prompts, completions, **kwargs): + # simple nonsensical reward + return [-((len(c) - 25) ** 2) + 100 for c in completions] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=1, # Minimal batch size + gradient_accumulation_steps=2, # Maintain effective batch size + num_generations=2, + max_completion_length=8, # Much shorter completions + bf16=True, # Use bfloat16 precision + max_steps=1, # Only do 1 training step to save time and memory + report_to="none", + logging_strategy="no", + ) + lora_config = LoraConfig( + task_type="CAUSAL_LM", + r=4, # Much lower rank for minimal memory + lora_alpha=8, # Reduced alpha proportionally + lora_dropout=0.1, + target_modules=["q_proj", "v_proj"], # Minimal target modules + # For VLM models, we typically want to freeze the vision encoder + # and only adapt the language model parameters + modules_to_save=None, + ) + + try: + trainer = GRPOTrainer( + model=model, + processing_class=processor, + reward_funcs=[reward_func], + args=training_args, + train_dataset=dataset, + peft_config=lora_config, + ) + + assert isinstance(trainer.model, PeftModel) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that LoRA parameters have changed + # For VLM models, we're more permissive about which parameters can change + lora_params_changed = False + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if "lora" in n.lower(): # LoRA parameters should change + if not torch.equal(param, new_param): + lora_params_changed = True + + # At least some LoRA parameters should have changed during training + assert lora_params_changed, "No LoRA parameters were updated during 
training." + + except torch.OutOfMemoryError as e: + pytest.skip(f"Skipping VLM training test due to insufficient GPU memory: {e}") + except Exception as e: + # Check for other memory-related errors + if any(keyword in str(e).lower() for keyword in ["memory", "cuda", "out of memory", "insufficient"]): + pytest.skip(f"Skipping VLM training test due to hardware constraints: {e}") + else: + raise + + release_memory(model, trainer) + + @require_vllm + @require_bitsandbytes + @require_peft + def test_vlm_processor_vllm_colocate_mode(self): + """ + Test that VLM processors work with vLLM in colocate mode. + + This test uses multiple memory optimization techniques to ensure it runs on limited hardware: + - LoRA (Low-Rank Adaptation) with minimal rank (r=4) + - 4-bit quantization with BitsAndBytesConfig + - Gradient checkpointing + - bfloat16 precision + - Minimal batch sizes and sequence lengths + - Very low GPU memory utilization (5%) + """ + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + config = GRPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, # Minimal batch size + gradient_accumulation_steps=2, # Make effective batch size 2, divisible by num_generations + num_generations=2, + max_completion_length=4, # Very short completions to reduce memory + use_vllm=True, # Enable vLLM + vllm_mode="colocate", # Use colocate mode to avoid server dependency + vllm_gpu_memory_utilization=0.05, # Use minimal GPU memory (5%) + bf16=True, # Use bfloat16 to reduce memory + report_to="none", + logging_strategy="no", + ) + + # Create a VLM processor + processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct", use_fast=True, padding_side="left") + + # Verify processor has both required attributes for VLM detection + assert hasattr(processor, "tokenizer") + assert hasattr(processor, "image_processor") + + def dummy_reward_func(completions, **kwargs): + return [1.0] * len(completions) + + # Use LoRA configuration for memory efficiency + lora_config = LoraConfig( + r=4, # Very low rank for minimal memory + lora_alpha=8, + target_modules=["q_proj", "v_proj"], # Minimal target modules + lora_dropout=0.1, + bias="none", + task_type="CAUSAL_LM", + ) + + # Use 4-bit quantization for further memory reduction + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + ) + + original_env = {} + required_env_vars = { + "RANK": "0", + "LOCAL_RANK": "0", + "WORLD_SIZE": "1", + "LOCAL_WORLD_SIZE": "1", + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12355", + } + + for key, value in required_env_vars.items(): + original_env[key] = os.environ.get(key) + os.environ[key] = value + + try: + # Test VLM processor with vLLM colocate mode + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + try: + # Load model with quantization for memory efficiency + model = AutoModelForCausalLM.from_pretrained( + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + quantization_config=quantization_config, + dtype=torch.bfloat16, + ) + + trainer = GRPOTrainer( + model=model, + reward_funcs=dummy_reward_func, + args=config, + train_dataset=dataset, + processing_class=processor, # VLM processor + peft_config=lora_config, # Use LoRA for memory efficiency + ) + + # Should detect VLM processor correctly and allow vLLM + assert trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode" + assert trainer.vllm_mode 
== "colocate", "Should use colocate mode" + + # Check if signature columns were set properly + if trainer._signature_columns is not None: + # Should include 'image' in signature columns for VLM processors + assert "image" in trainer._signature_columns, ( + "Should include 'image' in signature columns for VLM" + ) + + # Should not emit any warnings about VLM incompatibility + incompatibility_warnings = [ + str(w_item.message) + for w_item in w + if "does not support VLMs" in str(w_item.message) + or "not compatible" in str(w_item.message).lower() + ] + assert len(incompatibility_warnings) == 0, ( + f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}" + ) + + # Test passes if we get this far without exceptions + + except Exception as e: + # If vLLM fails to initialize due to hardware constraints or other issues, that's expected + if any( + keyword in str(e).lower() + for keyword in [ + "outofmemoryerror", + "cuda", + "memory", + "insufficient", + "no such device", + "free memory", + "gpu memory utilization", + "decrease gpu memory", + ] + ): + pytest.skip(f"Skipping vLLM colocate test due to hardware constraints: {e}") + elif "KeyError" in str(e) and "RANK" in str(e): + pytest.skip(f"Skipping vLLM colocate test due to environment setup issues: {e}") + elif "ValueError" in str(e) and "memory" in str(e).lower(): + pytest.skip(f"Skipping vLLM colocate test due to memory constraints: {e}") + else: + raise + finally: + # Restore original environment variables + for key, original_value in original_env.items(): + if original_value is None: + os.environ.pop(key, None) + else: + os.environ[key] = original_value + + release_memory(model, trainer) + + @require_vllm + def test_training_vllm(self): + """Test that training works with vLLM for generation.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + logging_strategy="no", + use_vllm=True, + ) + + try: + trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", # tiny models are too small for vLLM + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." 
+ + except Exception as e: + # If vLLM fails to initialize due to hardware constraints or other issues, that's expected + if any( + keyword in str(e).lower() + for keyword in [ + "outofmemoryerror", + "cuda", + "memory", + "insufficient", + "no such device", + "free memory", + "gpu memory utilization", + "decrease gpu memory", + ] + ): + pytest.skip(f"Skipping vLLM training test due to hardware constraints: {e}") + elif "KeyError" in str(e) and "RANK" in str(e): + pytest.skip(f"Skipping vLLM training test due to environment setup issues: {e}") + elif "ValueError" in str(e) and "memory" in str(e).lower(): + pytest.skip(f"Skipping vLLM training test due to memory constraints: {e}") + else: + raise + + release_memory(trainer.model, trainer) diff --git a/ICL/RL/trl_source/tests/test_model_utils.py b/ICL/RL/trl_source/tests/test_model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7e3fde991ddcf4469b3efccb78bb412e1da1e8a6 --- /dev/null +++ b/ICL/RL/trl_source/tests/test_model_utils.py @@ -0,0 +1,34 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import AutoModelForCausalLM + +from trl.models.utils import disable_gradient_checkpointing + + +class TestDisableGradientCheckpointing: + def test_when_disabled(self): + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + assert model.is_gradient_checkpointing is False + with disable_gradient_checkpointing(model): + assert model.is_gradient_checkpointing is False + assert model.is_gradient_checkpointing is False + + def test_when_enabled(self): + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + model.gradient_checkpointing_enable() + assert model.is_gradient_checkpointing is True + with disable_gradient_checkpointing(model): + assert model.is_gradient_checkpointing is False + assert model.is_gradient_checkpointing is True diff --git a/ICL/RL/trl_source/tests/test_reward_trainer.py b/ICL/RL/trl_source/tests/test_reward_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..c1bcd97b713a78841dfe7dfec80225f7809ac6cf --- /dev/null +++ b/ICL/RL/trl_source/tests/test_reward_trainer.py @@ -0,0 +1,904 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
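+
+# The DataCollatorForPreference tests below assume the collator stacks all chosen
+# sequences first and all rejected sequences second, so a batch of N preference
+# pairs collates into 2N padded rows ([chosen_1, chosen_2, rejected_1, rejected_2]).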
+ +import pathlib + +import pytest +import torch +from datasets import load_dataset +from transformers import AutoModelForSequenceClassification, AutoTokenizer +from transformers.utils import is_peft_available + +from trl import RewardConfig, RewardTrainer +from trl.trainer.reward_trainer import DataCollatorForPreference + +from .testing_utils import TrlTestCase, require_peft + + +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + +class TestDataCollatorForPreference(TrlTestCase): + def test_basic_padding(self): + """Test basic padding functionality without completion masks.""" + collator = DataCollatorForPreference(pad_token_id=0) + examples = [ + {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}, + {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]}, + ] + + result = collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]])) + torch.testing.assert_close( + result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]]) + ) + + def test_pad_to_multiple_of(self): + """Test padding to multiple of specified value.""" + collator = DataCollatorForPreference(pad_token_id=0, pad_to_multiple_of=4) + examples = [ + {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}, + {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]}, + ] + + result = collator(examples) + + torch.testing.assert_close( + result["input_ids"], torch.tensor([[1, 2, 3, 0], [6, 7, 0, 0], [4, 5, 0, 0], [8, 0, 0, 0]]) + ) + torch.testing.assert_close( + result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 1, 0, 0], [1, 0, 0, 0]]) + ) + + def test_single_example(self): + """Test collator with a single example.""" + collator = DataCollatorForPreference(pad_token_id=0) + examples = [{"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}] + + result = collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + + def test_different_pad_token_id(self): + """Test with different pad token ID.""" + collator = DataCollatorForPreference(pad_token_id=999) + examples = [ + {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}, + {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]}, + ] + + result = collator(examples) + + torch.testing.assert_close( + result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 999], [4, 5, 999], [8, 999, 999]]) + ) + torch.testing.assert_close( + result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]]) + ) + + def test_collate_with_margin(self): + collator = DataCollatorForPreference(pad_token_id=0) + examples = [ + {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.1}, + {"chosen_input_ids": [6, 7], "rejected_input_ids": [8], "margin": 0.2}, + ] + + result = collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]])) + torch.testing.assert_close( + result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]]) + ) + torch.testing.assert_close(result["margin"], torch.tensor([0.1, 0.2])) + + +class TestRewardTrainer(TrlTestCase): + def test_raises_error_when_model_num_labels_not_one(self): + """Test that RewardTrainer raises ValueError when model doesn't have num_labels=1.""" + model = AutoModelForSequenceClassification.from_pretrained( + 
"trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + dtype="float32", + # num_labels=2, # Defaults to 2 num_labels for causal models + ) + + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + with pytest.raises(ValueError, match=r"reward models require `num_labels=1`"): + RewardTrainer(model=model, args=training_args) + + @pytest.mark.parametrize( + "model_id", + [ + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "trl-internal-testing/tiny-Qwen3MoeForCausalLM", + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + ], + ) + def test_train(self, model_id): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer(model=model_id, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + @pytest.mark.parametrize( + "config_name", + [ + "standard_preference", + "conversational_preference", + "standard_implicit_prompt_preference", + "conversational_implicit_prompt_preference", + ], + ) + def test_train_dataset_types(self, config_name): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", config_name, split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_model(self): + # Instantiate the model + model = AutoModelForSequenceClassification.from_pretrained( + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + num_labels=1, # required for reward models + dtype="float32", + ) + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not 
torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    def test_train_from_sequence_classification_model(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    def test_train_model_dtype(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(
+            output_dir=self.tmp_dir,
+            model_init_kwargs={"dtype": torch.float16},
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            report_to="none",
+        )
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            # For some reason model.layers.0.input_layernorm.weight doesn't change in GitHub Actions but does
+            # locally. We ignore this parameter for now.
+            if "layernorm" in n:
+                continue
+            new_param = trainer.model.get_parameter(n)
+            # Check the torch dtype
+            assert new_param.dtype == torch.float16
+            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    @require_peft
+    def test_train_dense_with_peft_config(self):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id, dtype="float32")
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+
+        trainer = RewardTrainer(
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(),
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    @require_peft
+    def test_train_moe_with_peft_config(self):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen3MoeForCausalLM"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id, dtype="float32")
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+
+        trainer = RewardTrainer(
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(target_modules=["up_proj", "down_proj", "score"]),
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    @require_peft
+    def test_train_peft_model(self):
+        # Get the base model
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_id,
+            num_labels=1,  # required for reward models
+            dtype="float32",
+        )
+
+        # Get the base model parameter names
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Turn the model into a peft model
+        lora_config = LoraConfig()
+        model = get_peft_model(model, lora_config)
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset)
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    # In practice, this test is the same as `test_train_dense_with_peft_config`, since gradient checkpointing is
+    # enabled by default in `RewardTrainer`. We keep it as a regression guard: if the default ever changes, we still
+    # explicitly test PEFT + gradient checkpointing, which has caused issues in the past.
+    @require_peft
+    def test_train_with_peft_config_and_gradient_checkpointing(self):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id, dtype="float32")
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
+
+        trainer = RewardTrainer(
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(),
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    @require_peft
+    def test_train_with_peft_config_and_gradient_checkpointing_reentrant(self, use_reentrant):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id, dtype="float32")
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(
+            output_dir=self.tmp_dir,
+            gradient_checkpointing=True,
+            gradient_checkpointing_kwargs={"use_reentrant": use_reentrant},
+            report_to="none",
+        )
+
+        trainer = RewardTrainer(
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(),
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    def test_train_with_pretokenized_data(self):
+        # Get the dataset
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        dataset =
load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + def tokenize_example(example): + return { + "chosen_input_ids": tokenizer(example["chosen"]).input_ids, + "rejected_input_ids": tokenizer(example["rejected"]).input_ids, + } + + # Apply tokenization + tokenized_dataset = dataset.map(tokenize_example, remove_columns=["chosen", "rejected"]) + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer(model=model_id, args=training_args, train_dataset=tokenized_dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_with_iterable_dataset(self): + # Get the dataset + dataset = load_dataset( + "trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train", streaming=True + ) + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_with_chat_template_kwargs(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + # The following template is a simplified version of the Qwen chat template, where an additional argument + # `role_capital` is used to control the capitalization of roles. + tokenizer.chat_template = '{%- if messages[0]["role"] == "system" -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\n" + messages[0]["content"] + "<|im_end|>\\n" }}{%- else -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n" }}{%- endif -%}{%- for message in messages -%} {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) -%} {{ "<|im_start|>" + (message.role.upper() if role_capital else message.role) + "\\n" + message.content + "<|im_end|>\\n" }} {%- elif message.role == "assistant" -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") }} {%- if message.content -%} {{ "\\n" + message.content }} {%- endif -%} {{ "<|im_end|>\\n" }} {%- elif message.role == "tool" -%} {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") -%} {{ "<|im_start|>" + ("USER" if role_capital else "user") }} {%- endif -%} {{ "\\n\\n" + message.content + "\\n" }} {%- if loop.last or (messages[loop.index0 + 1].role != "tool") -%} {{ "<|im_end|>\\n" }} {%- endif -%} {%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") + "\\n" }}{%- endif -%}' + + dataset = dataset.add_column( + "chat_template_kwargs", [{"role_capital": bool(i % 2)} for i in range(len(dataset))] + ) + assert "chat_template_kwargs" in dataset.features + + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + ) + + # Assert trainer uses the same chat template as tokenizer + assert trainer.processing_class.chat_template == tokenizer.chat_template + + # Assert chat_template is applied + for i in range(2): + role = "SYSTEM" if i else "system" + system_prompt = ( + f"<|im_start|>{role}\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>" + ) + system_prompt_ids = trainer.processing_class(system_prompt)["input_ids"] + assert trainer.train_dataset[i]["chosen_input_ids"][: len(system_prompt_ids)] == system_prompt_ids + assert trainer.train_dataset[i]["rejected_input_ids"][: len(system_prompt_ids)] == system_prompt_ids + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_with_set_chat_template_from_model(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, chat_template_path="Qwen/Qwen3-4B", report_to="none") + # trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default + trainer = RewardTrainer( + model="trl-internal-testing/tiny-GPTNeoXForCausalLM", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # RewardTrainer 
uses a mean-free loss that cancels uniform shifts in output scores. Since GPT-NeoX models + # include a final LayerNorm, its bias consistently receives zero gradient and remains unchanged, so we skip + # this parameter. + if n == "gpt_neox.final_layer_norm.bias": + continue + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_with_set_chat_template_from_path(self, lazy_shared_datadir): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig( + output_dir=self.tmp_dir, + chat_template_path=str(lazy_shared_datadir / "template.jinja"), + report_to="none", + ) + # trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default + trainer = RewardTrainer( + model="trl-internal-testing/tiny-GPTNeoXForCausalLM", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # RewardTrainer uses a mean-free loss that cancels uniform shifts in output scores. Since GPT-NeoX models + # include a final LayerNorm, its bias consistently receives zero gradient and remains unchanged, so we skip + # this parameter. + if n == "gpt_neox.final_layer_norm.bias": + continue + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + # Check that the template saved in the output directory is the same as the one used for training + template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja" + assert template_path.exists(), f"Chat template not found at {template_path}" + + with open(template_path) as f: + template_content = f.read() + with open(training_args.chat_template_path) as f: + original_template_content = f.read() + assert template_content == original_template_content, "Chat template content does not match the original" + + def test_train_toolcall_data(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/toolcall", "preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_with_eval(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none") + trainer = RewardTrainer( + 
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + # Train the model + trainer.train() + + # Check that the eval loss is not None + assert trainer.state.log_history[0]["eval_loss"] is not None + + def test_train_with_multiple_eval_dataset(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset={"data1": dataset["test"], "data2": dataset["test"]}, + ) + # Train the model + trainer.train() + + # Check that the eval losses are not None + assert trainer.state.log_history[-3]["eval_data1_loss"] is not None + assert trainer.state.log_history[-2]["eval_data2_loss"] is not None + + def test_train_with_compute_metrics(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference") + + def dummy_compute_metrics(eval_pred): + return {"my_metric": 0.123} + + # Initialize the trainer + training_args = RewardConfig( + output_dir=self.tmp_dir, + eval_strategy="steps", + eval_steps=3, + report_to="none", + ) + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + compute_metrics=dummy_compute_metrics, + ) + + # Train the model + trainer.train() + + # Check that the custom metric is logged + assert trainer.state.log_history[-2]["eval_my_metric"] == 0.123 + + # In practice, this test is the same as `test_train`, since gradient checkpointing is enabled by default in + # `RewardTrainer`. We keep it as a regression guard: if the default ever changes, we still explicitly test gradient + # checkpointing, which has caused issues in the past. 
+ def test_train_with_gradient_checkpointing(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + @pytest.mark.parametrize("use_reentrant", [True, False]) + def test_train_with_gradient_checkpointing_reentrant(self, use_reentrant): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + training_args = RewardConfig( + output_dir=self.tmp_dir, + gradient_checkpointing=True, + gradient_checkpointing_kwargs={"use_reentrant": use_reentrant}, + report_to="none", + ) + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_tag_added(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + train_dataset=dataset, + ) + + for tag in ["reward-trainer", "trl"]: + assert tag in trainer.model.model_tags + + @require_peft + def test_tag_added_peft(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Initialize the trainer + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + train_dataset=dataset, + peft_config=LoraConfig(), + ) + + for tag in ["reward-trainer", "trl"]: + assert tag in trainer.model.model_tags + + def test_train_with_margin(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + def add_margin(example): + # dummy margin based on the length of the chosen summary + return {"margin": len(example["chosen"])} + + dataset = dataset.map(add_margin) + + # Initialize the trainer + training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none") + trainer = RewardTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them 
later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    def test_train_with_center_rewards_coefficient(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, center_rewards_coefficient=0.01, report_to="none")
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
diff --git a/ICL/RL/trl_source/tests/test_rewards.py b/ICL/RL/trl_source/tests/test_rewards.py
new file mode 100644
index 0000000000000000000000000000000000000000..2442b3bd2ce0a9e8b71ed3a359a86880fcacc441
--- /dev/null
+++ b/ICL/RL/trl_source/tests/test_rewards.py
@@ -0,0 +1,192 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from trl.rewards import accuracy_reward, get_soft_overlong_punishment, reasoning_accuracy_reward, think_format_reward
+
+from .testing_utils import TrlTestCase, require_math_latex
+
+
+class TestThinkFormatReward(TrlTestCase):
+    def test_valid_format(self):
+        completions = [
+            "<think>This is my reasoning.</think>This is my answer.",  # Simple, one-line reasoning
+            "<think>\nThis is my reasoning.\n</think>\nThis is my answer.",  # Multiline reasoning
+            "<think>\nThis is\nmy reasoning.\n</think>\nThis is my answer.",  # Multiline reasoning
+            "<think>\nThis is my reasoning.<some_tag></think>\nThis is my answer.",  # Reasoning including other tags
+            "<think></think>\nThis is my answer.",  # Empty reasoning
+        ]
+        completions = [[{"content": completion}] for completion in completions]
+        expected_rewards = [1.0, 1.0, 1.0, 1.0, 1.0]  # All should be valid
+        rewards = think_format_reward(completions)
+        assert rewards == expected_rewards
+
+    def test_invalid_format(self):
+        completions = [
+            "<think>\nThis is my reasoning.\nThis is my answer.",  # No closing </think>
+            "<think>This is my reasoning.\nThis is my answer.",  # No closing </think>
+            "This is my reasoning. This is my answer.",  # No <think> tags
+            "This is my reasoning.\nThis is my answer.",  # No <think> tags
+            "This is my reasoning.</think>\nThis is my answer.",  # No opening <think>
+            "This is my reasoning.</think>This is my answer.",  # No opening <think>
+            "This<think>is my reasoning.\n</think>This is my answer.",  # <think> tag in the middle
+            "<think>This is<think>my reasoning.</think></think>This is my answer.",  # Nested <think> tags
+            "This is\nmy\nreasoning.\nThis is my answer.",  # Multiline, no <think> tags
+        ]
+        completions = [[{"content": completion}] for completion in completions]
+        expected_rewards = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]  # All should be invalid
+        rewards = think_format_reward(completions)
+        assert rewards == expected_rewards
+
+    def test_mixed_format(self):
+        completions = [
+            "<think>This is my reasoning.</think>This is my answer.",  # Valid
+            "<think>\nThis is my reasoning.\n</think>\nThis is my answer.",  # Valid
+            "<think>This is my reasoning.\nThis is my answer.",  # Invalid, no closing </think>
+            "This is my reasoning. This is my answer.",  # Invalid, no <think> tags
+        ]
+        completions = [[{"content": completion}] for completion in completions]
+        expected_rewards = [1.0, 1.0, 0.0, 0.0]
+        rewards = think_format_reward(completions)
+        assert rewards == expected_rewards
+
+
+class TestSoftOverlongPunishmentReward:
+    def test_soft_overlong_punishment_short_completion(self):
+        """Test soft overlong punishment reward function with a short completion."""
+        # length 50, with max=100 and soft cache=20, reward should be 0.
+        reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
+        completion_ids = [[1] * 50]  # 50 <= 80
+        rewards = reward_fn(completion_ids=completion_ids)
+        assert rewards == [0]
+
+    def test_soft_overlong_punishment_long_completion(self):
+        """Test soft overlong punishment reward function with a longer than max completion."""
+        # 110 > 100, reward should be -1.
+        reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
+        completion_ids = [[1] * 110]
+        rewards = reward_fn(completion_ids)
+        assert rewards == [-1]
+
+    def test_soft_overlong_punishment_intermediate_completion(self):
+        """Test soft overlong punishment reward function for intermediate length completion."""
+        reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
+        completion_ids = [[1] * 90]  # 90 is between 80 and 100
+        rewards = reward_fn(completion_ids)
+        assert round(abs(rewards[0] - -0.5), 4) == 0
+
+
+class TestAccuracyReward:
+    @require_math_latex
+    def test_accuracy_reward_correct_answer(self):
+        """Test accuracy_reward with a correct answer."""
+        completion = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{63}{400}}"}]]
+        solution = [r"\frac{63}{400}", "63/400"]
+        rewards = accuracy_reward(completion, solution)
+        assert rewards[0] == 1.0
+        assert rewards[1] == 1.0
+
+    @require_math_latex
+    def test_accuracy_reward_wrong_answer(self):
+        """Test accuracy_reward with an incorrect answer."""
+        completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
+        solution = [r"\frac{63}{400}"]
+        rewards = accuracy_reward(completion, solution)
+        assert rewards[0] == 0.0
+
+    @require_math_latex
+    def test_accuracy_reward_wrong_answer_no_latex(self):
+        """Test accuracy_reward with an incorrect answer and gold solution with no latex."""
+        completion = [[{"content": r"\boxed{3}"}]]
+        solution = ["6"]
+        rewards = accuracy_reward(completion, solution)
+        assert rewards[0] == 0.0
+
+    @require_math_latex
+    def test_accuracy_reward_unparsable_gold(self):
+        """Test accuracy_reward with an unparsable gold solution."""
+        completion = [
+            [{"content": "Answer is forty two."}],
+            [{"content": r"Some other content. \boxed{43}."}],
+        ]
+        solution = [
+            "Answer is forty two.",
+            "Answer is forty three.",
+        ]
+        rewards = accuracy_reward(completion, solution)
+        assert rewards[0] is None
+        assert rewards[1] is None
+
+
+class TestReasoningAccuracyReward:
+    @require_math_latex
+    def test_correct_answer_yields_unit_reward(self):
+        completions = [
+            [{"content": r"<think> Reasoning content </think>\boxed{\frac{63}{400}}"}],
+            [{"content": r"<think>Reasoning content</think> \boxed{\frac{63}{400}}"}],
+        ]
+        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
+        rewards = reasoning_accuracy_reward(completions, solutions)
+        assert rewards[0] == 1.0
+        assert rewards[1] == 1.0
+
+    @require_math_latex
+    def test_correct_answer_with_custom_tags_yields_unit_reward(self):
+        completions = [
+            [{"content": r"<reasoning> Reasoning content </reasoning>\boxed{\frac{63}{400}}"}],
+        ]
+        solutions = [
+            r"\frac{63}{400}",
+        ]
+        rewards = reasoning_accuracy_reward(completions, solutions, reasoning_delimiters=["</reasoning>"])
+        assert rewards[0] == 1.0
+
+    @require_math_latex
+    def test_incorrect_answer_yields_zero_reward(self):
+        completion = [[{"content": r"<think> Reasoning content </think>\boxed{\frac{64}{400}}"}]]
+        solution = [r"\frac{63}{400}"]
+        rewards = reasoning_accuracy_reward(completion, solution)
+        assert rewards[0] == 0.0
+
+    @require_math_latex
+    def test_correct_answer_in_reasoning_yields_zero_reward(self):
+        completions = [
+            [{"content": r"<think> My answer is \boxed{42} </think>Some other text."}],
+            [{"content": r"<think> The answer is \boxed{42} </think>Here's a wrong answer: \boxed{43}."}],
+        ]
+        solutions = [r"\boxed{42}", r"\boxed{42}"]
+        rewards = reasoning_accuracy_reward(completions, solutions)
+        assert rewards[0] == 0.0
+        assert rewards[1] == 0.0
+
+    @require_math_latex
+    def test_incomplete_reasoning_yields_zero_reward(self):
+        completions = [
+            [{"content": r"<think> Incomplete reasoning without closing tag"}],
+            [{"content": r"Correct answer \frac{63}{400} but completely missing reasoning content"}],
+        ]
+        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
+        rewards = reasoning_accuracy_reward(completions, solutions)
+        assert rewards[0] == 0.0
+        assert rewards[1] == 0.0
+
+    @require_math_latex
+    def test_unparsable_gold_solution_yields_none_reward(self):
+        completions = [
+            [{"content": r"<think> Reasoning content </think>\boxed{42}"}],
+        ]
+        solutions = [
+            "forty two",
+        ]
+        rewards = reasoning_accuracy_reward(completions, solutions)
+        assert rewards[0] is None
diff --git a/ICL/RL/trl_source/tests/test_sft_trainer.py b/ICL/RL/trl_source/tests/test_sft_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..21aa2581f13efbd93e42d915e1c2566b28f4878f
--- /dev/null
+++ b/ICL/RL/trl_source/tests/test_sft_trainer.py
@@ -0,0 +1,2125 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import gc +import pathlib +from unittest.mock import MagicMock, patch + +import pytest +import torch +import transformers +from accelerate.utils.memory import release_memory +from datasets import load_dataset +from packaging.version import Version +from packaging.version import parse as parse_version +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments +from transformers.testing_utils import backend_empty_cache, torch_device +from transformers.utils import is_peft_available + +from trl import SFTConfig, SFTTrainer +from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, dft_loss + +from .testing_utils import ( + TrlTestCase, + ignore_warnings, + require_ampere_or_newer, + require_bitsandbytes, + require_kernels, + require_liger_kernel, + require_peft, + require_torch_accelerator, + require_torch_multi_accelerator, + require_vision, +) + + +if is_peft_available(): + import peft + from peft import ( + LoraConfig, + PeftModel, + PrefixTuningConfig, + PromptEncoderConfig, + PromptTuningConfig, + TaskType, + get_peft_model, + ) + + +class TestDFTLoss(TrlTestCase): + def test_dft_loss(self): + batch_size = 2 + seq_len = 3 + vocab_size = 2 + # All tokens have the same probability + logits = torch.fill(torch.empty(batch_size, seq_len, vocab_size), torch.rand(1).item()) + outputs = MagicMock() + outputs.logits = logits + labels = torch.tensor([[1, 0, 0], [0, 1, -100]]) + ce_loss = torch.nn.functional.cross_entropy( + logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100, reduction="mean" + ) + # We need to account for the logits shift operation so we don't consider the first tokens + # in each row of the batch + num_items_in_batch = 3 + # Dft loss + predicted_dft_loss = dft_loss(outputs, labels, num_items_in_batch) + # If we have just two tokens in our vocab and all logits are the same, + # dft scales the ce_loss per token by 0.5. 
So the dft_loss should be ce_loss/2 + torch.testing.assert_close(ce_loss / 2.0, predicted_dft_loss, atol=1e-4, rtol=1e-4) + + +class TestDataCollatorForLanguageModeling(TrlTestCase): + def test_basic_padding(self): + """Test basic padding functionality without completion masks.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0) + examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] + + result = collator(examples) + + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) + + def test_completion_mask(self): + """Test completion mask functionality.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0) + examples = [ + {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]}, + {"input_ids": [4, 5], "completion_mask": [0, 1]}, + ] + + result = collator(examples) + + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]])) + + def test_completion_only_loss_disabled(self): + """Test behavior when completion_only_loss is disabled.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0, completion_only_loss=False) + examples = [ + {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]}, + {"input_ids": [4, 5], "completion_mask": [0, 1]}, + ] + + result = collator(examples) + + # Labels should not be masked when completion_only_loss=False + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) + + def test_padding_free_mode(self): + """Test padding-free mode where sequences are concatenated.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) + examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] + + result = collator(examples) + + assert set(result.keys()) == {"input_ids", "position_ids", "labels"} + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]])) + torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]])) + torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5]])) + + def test_padding_free_with_completion_mask(self): + """Test padding-free mode with completion masks.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) + examples = [ + {"input_ids": [1, 2, 3], "completion_mask": [0, 0, 1]}, + {"input_ids": [4, 5], "completion_mask": [1, 1]}, + ] + + result = collator(examples) + + assert set(result.keys()) == {"input_ids", "position_ids", "labels"} + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]])) + torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]])) + torch.testing.assert_close(result["labels"], torch.tensor([[-100, -100, 3, -100, 5]])) + + def test_packing(self): + """Test that when using packing 
with position_ids, attention_mask is dropped with fa2.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) + + # Simulate packed sequences with position_ids that restart (typical of BFD packing) + examples = [ + {"input_ids": [1, 2, 3, 4, 5, 6], "seq_lengths": [3, 3]}, + {"input_ids": [7, 8, 9, 10, 11], "seq_lengths": [4, 1]}, + ] + + result = collator(examples) + + assert set(result.keys()) == {"input_ids", "position_ids", "labels"} + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])) + torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, 6, -100, 8, 9, 10, -100]])) + + def test_pad_to_multiple_of(self): + """Test padding to multiple of specified value.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0, pad_to_multiple_of=4) + examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] + + result = collator(examples) + + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]])) + + def test_pad_to_multiple_of_and_padding_free(self): + """Test padding to multiple of specified value.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, pad_to_multiple_of=4) + examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] + + result = collator(examples) + + assert set(result.keys()) == {"input_ids", "position_ids", "labels"} + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]])) + torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, -100, -100, -100]])) + + def test_custom_position_ids_but_no_padding_free(self): + """Test that custom position_ids are ignored if padding_free is False.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0) + examples = [{"input_ids": [1, 2, 3], "seq_lengths": [1, 2]}, {"input_ids": [4, 5], "seq_lengths": [2]}] + + result = collator(examples) + + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) + + def test_single_example(self): + """Test collator with a single example.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0) + examples = [{"input_ids": [1, 2, 3, 4]}] + + result = collator(examples) + + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]])) + torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]])) + + def test_different_pad_token_id(self): + """Test with different pad token ID.""" + collator = DataCollatorForLanguageModeling(pad_token_id=999) + examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] + + result = collator(examples) + + 
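+        # input_ids are padded with the custom pad_token_id (999), while labels are
+        # always padded with -100 so padded positions stay masked out of the loss.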
assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) + + def test_assistant_masks(self): + """Test handling of assistant masks in examples.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0) + examples = [ + {"input_ids": [1, 2, 3], "assistant_masks": [0, 1, 1]}, + {"input_ids": [4, 5], "assistant_masks": [0, 1]}, + ] + + result = collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]])) + + def test_single_example_single_doc(self): + batch_seq_lengths = [[5]] + result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths) + assert len(result) == 1 + assert torch.equal(result[0], torch.arange(5)) + + def test_single_example_multiple_docs(self): + batch_seq_lengths = [[3, 2]] + result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths) + assert len(result) == 1 + # First sequence: 0, 1, 2; second sequence: 0, 1 + assert torch.equal(result[0], torch.tensor([0, 1, 2, 0, 1])) + + def test_multiple_examples(self): + batch_seq_lengths = [[2, 2], [3]] + result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths) + assert len(result) == 2 + assert torch.equal(result[0], torch.tensor([0, 1, 0, 1])) + assert torch.equal(result[1], torch.arange(3)) + + +class TestSFTTrainer(TrlTestCase): + def test_init_with_training_arguments(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + args = TrainingArguments(output_dir=self.tmp_dir, report_to="none") + SFTTrainer(model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=args, train_dataset=dataset) + + @pytest.mark.parametrize( + "model_id", + [ + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "trl-internal-testing/tiny-Qwen3MoeForCausalLM", + "trl-internal-testing/tiny-GptOssForCausalLM", + ], + ) + def test_train(self, model_id): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") + trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + # Special case for harmony + def test_train_gpt_oss(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/harmony", "language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") + trainer = SFTTrainer( + 
model="trl-internal-testing/tiny-GptOssForCausalLM", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_model(self): + # Instantiate the model + model = AutoModelForCausalLM.from_pretrained( + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + dtype="float32", + ) + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") + trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_dft_loss(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling") + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + loss_type="dft", + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + report_to="none", + eval_strategy="steps", + eval_steps=3, + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_moe_model_with_aux_loss(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + report_to="none", + model_init_kwargs={"output_router_logits": True}, + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", args=training_args, train_dataset=dataset + ) + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss and aux loss are not None + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["aux_loss"] is not None 
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    def test_train_with_formatting_func(self):
+        # Dummy formatting function
+        def formatting_prompts_func(example):
+            chosen, rejected = example["chosen"], example["rejected"]
+            return f"### Chosen: {chosen}\n### Rejected: {rejected}"
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = SFTTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            args=training_args,
+            train_dataset=dataset,
+            formatting_func=formatting_prompts_func,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    def test_train_model_dtype(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(
+            output_dir=self.tmp_dir,
+            model_init_kwargs={"dtype": torch.float16},
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            report_to="none",
+        )
+        trainer = SFTTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            # For some reason model.layers.0.input_layernorm.weight doesn't change in GitHub Actions but does
+            # locally. We ignore this parameter for now
+            if "layernorm" in n:
+                continue
+            new_param = trainer.model.get_parameter(n)
+            # Check the torch dtype
+            assert new_param.dtype == torch.float16
+            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    @require_peft
+    def test_train_dense_with_peft_config_lora(self):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
+
+        trainer = SFTTrainer(
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(),
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    @pytest.mark.parametrize(
+        "peft_type",
+        [
+            "prompt_tuning",
+            "prefix_tuning",
+            "prompt_encoder",
+        ],
+    )
+    @require_peft
+    def test_train_with_peft_config_prompt_tuning(self, peft_type):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
+        base_param_names = [f"base_model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        # Initialize the trainer, p-tuning doesn't support gradient checkpointing
+        training_args = SFTConfig(bf16=False, output_dir=self.tmp_dir, report_to="none", gradient_checkpointing=False)
+        if peft_type == "prompt_tuning":
+            peft_config = PromptTuningConfig(
+                task_type=TaskType.CAUSAL_LM,
+                num_virtual_tokens=4,
+                tokenizer_name_or_path="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            )
+        elif peft_type == "prefix_tuning":
+            if parse_version(peft.__version__) <= Version("0.17.1"):
+                pytest.xfail(
+                    "Prefix tuning with device_map='auto' is broken in peft 0.17.1 and below. See "
+                    "https://github.com/huggingface/peft/issues/2821"
+                )
+            peft_config = PrefixTuningConfig(
+                task_type=TaskType.CAUSAL_LM,
+                num_virtual_tokens=4,
+            )
+        elif peft_type == "prompt_encoder":
+            peft_config = PromptEncoderConfig(
+                task_type=TaskType.CAUSAL_LM,
+                num_virtual_tokens=4,
+                encoder_hidden_size=model.config.hidden_size,  # This will be overwritten below
+            )
+        trainer = SFTTrainer(
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=peft_config,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            else:  # We expect the peft parameters to be different
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    @require_peft
+    def test_train_moe_with_peft_config(self):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-GptOssForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
+
+        trainer = SFTTrainer(
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(target_parameters=["mlp.experts.down_proj", "mlp.experts.gate_up_proj"]),
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    @require_peft
+    def test_train_peft_model(self):
+        # Get the base model
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
+
+        # Get the base model parameter names
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Turn the model into a peft model
+        lora_config = LoraConfig()
+        model = get_peft_model(model, lora_config)
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    # In practice, this test is the same as `test_train_dense_with_peft_config_lora`, since gradient checkpointing is
+    # enabled by default in `SFTTrainer`. We keep it as a regression guard: if the default ever changes, we still
+    # explicitly test PEFT + gradient checkpointing, which has caused issues in the past.
+    @require_peft
+    def test_train_with_peft_config_and_gradient_checkpointing(self):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
+
+        trainer = SFTTrainer(
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(),
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    @require_peft
+    def test_train_with_peft_config_and_gradient_checkpointing_reentrant(self, use_reentrant):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32")
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(
+            output_dir=self.tmp_dir,
+            gradient_checkpointing=True,
+            gradient_checkpointing_kwargs={"use_reentrant": use_reentrant},
+            report_to="none",
+        )
+
+        trainer = SFTTrainer(
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=LoraConfig(),
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    @require_liger_kernel
+    def test_train_with_liger(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(output_dir=self.tmp_dir, use_liger_kernel=True, report_to="none")
+        trainer = SFTTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    @require_torch_accelerator
+    @require_liger_kernel
+    def test_compute_loss_skip_logits_on_eval_without_metrics_with_liger(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:1]")
+
+        training_args = SFTConfig(
+            output_dir=self.tmp_dir,
+            use_liger_kernel=False,
+            report_to="none",
+            max_length=8,
+            bf16=False,
+        )
+        trainer = SFTTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            args=training_args,
+            train_dataset=dataset,
+            compute_metrics=None,
+        )
+        trainer.args.use_liger_kernel = True
+        trainer.model.eval()
+
+        captured = {}
+
+        def mock_super_compute_loss(model, inputs, return_outputs=False, num_items_in_batch=None):
+            captured["skip_logits"] = inputs.get("skip_logits")
+            dummy_loss = torch.tensor(1.0, requires_grad=True)
+            dummy_outputs = MagicMock()
+            dummy_outputs.token_accuracy = None
+            dummy_outputs.logits = torch.randn(1, 5, trainer.model.config.vocab_size)
+            return (dummy_loss, dummy_outputs)
+
+        inputs = {
+            "input_ids": torch.tensor([[1, 2, 3, 4, 5]]),
+            "labels": torch.tensor([[1, 2, 3, 4, 5]]),
+            "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
+        }
+
+        with patch("trl.trainer.sft_trainer.BaseTrainer.compute_loss", side_effect=mock_super_compute_loss):
+            trainer.compute_loss(trainer.model, inputs)
+
+        assert captured["skip_logits"] is True
+
+    @require_torch_accelerator
+    @require_liger_kernel
+    def test_predict_does_not_skip_logits_with_liger(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:1]")
+
+        training_args = SFTConfig(
+            output_dir=self.tmp_dir,
+            use_liger_kernel=False,
+            report_to="none",
+            max_length=8,
+            
bf16=False, + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + compute_metrics=None, + ) + trainer.args.use_liger_kernel = True + trainer.model.eval() + + captured = {} + + def mock_super_compute_loss(model, inputs, return_outputs=False, num_items_in_batch=None): + captured["skip_logits"] = inputs.get("skip_logits") + dummy_loss = torch.tensor(1.0, requires_grad=True) + dummy_outputs = (dummy_loss, torch.randn(1, 5, trainer.model.config.vocab_size)) + return (dummy_loss, dummy_outputs) + + with patch("trl.trainer.sft_trainer.BaseTrainer.compute_loss", side_effect=mock_super_compute_loss): + trainer.predict(trainer.train_dataset) + + assert captured["skip_logits"] is False + + def test_train_with_non_chatml_conversational_data(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + + # Rename role/content to from/value to ensure SFT works with non-chatML conversational data + def rename_fields(example: list[dict]): + return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]} + + dataset = dataset.map(rename_fields, remove_columns="messages") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_with_pretokenized_data(self): + # Get the dataset + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + def tokenize_example(example): + return tokenizer(example["text"]) + + # Apply tokenization + tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"]) + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") + trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=tokenized_dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_with_iterable_dataset(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train", streaming=True) + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", 
args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + @require_kernels + @require_ampere_or_newer # Flash attention 2 requires Ampere or newer GPUs + def test_train_padding_free(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + padding_free=True, + model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}, + bf16=True, # flash_attention_2 only supports bf16 and fp16 + report_to="none", + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + @pytest.mark.parametrize("packing_strategy", ["bfd", "wrapped"]) + @ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning) + @ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning) + def test_train_packing(self, packing_strategy): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, packing=True, packing_strategy=packing_strategy, max_length=10, report_to="none" + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + @ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning) + @ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning) + def test_eval_packing(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling") + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + packing=True, + max_length=64, + report_to="none", + ) + trainer = SFTTrainer( + 
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + # Check the number of sequences in train and eval datasets + num_train_seqs = sum(len(x) for x in trainer.train_dataset["seq_lengths"]) + num_eval_seqs = sum(len(x) for x in trainer.eval_dataset["seq_lengths"]) + assert num_train_seqs == 17 # we should still have 17 seqs + assert num_eval_seqs == 2 # we should still have 2 seqs + + # Check that all sequences are shorter than the max length + assert all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"]) + assert all(sum(x) <= 64 for x in trainer.eval_dataset["seq_lengths"]) + + # Check the number of sequences in train and eval datasets + assert len(trainer.train_dataset["input_ids"]) == 3 # w/ this dataset, we end up with 46 seqs + assert len(trainer.eval_dataset["input_ids"]) == 1 # w/ this dataset, we end up with 6 seqs + + @ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning) + @ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning) + def test_only_train_packing(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling") + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + packing=True, + eval_packing=False, + max_length=64, + report_to="none", + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + # Check the number of sequences in train dataset + num_train_seqs = sum(len(x) for x in trainer.train_dataset["seq_lengths"]) + assert num_train_seqs == 17 # we should still have 17 seqs + + # We expect eval dataset not having "seq_lengths" as eval_packing is False + assert "seq_lengths" not in trainer.eval_dataset + + # Check that all sequences are shorter than the max length + assert all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"]) + + # Check the number of sequences in train and eval datasets + assert len(trainer.train_dataset["input_ids"]) == 3 # w/ this dataset, we end up with 46 seqs + assert len(trainer.eval_dataset["input_ids"]) == 2 # w/ this dataset, we end up with 6 seqs + + def test_train_with_chat_template_kwargs(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") + + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + # The following template is a simplified version of the Qwen chat template, where an additional argument + # `role_capital` is used to control the capitalization of roles. + tokenizer.chat_template = '{%- if messages[0]["role"] == "system" -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\n" + messages[0]["content"] + "<|im_end|>\\n" }}{%- else -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n" }}{%- endif -%}{%- for message in messages -%} {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) -%} {{ "<|im_start|>" + (message.role.upper() if role_capital else message.role) + "\\n" + message.content + "<|im_end|>\\n" }} {%- elif message.role == "assistant" -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") }} {%- if message.content -%} {{ "\\n" + message.content }} {%- endif -%} {{ "<|im_end|>\\n" }} {%- elif message.role == "tool" -%} {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") -%} {{ "<|im_start|>" + ("USER" if role_capital else "user") }} {%- endif -%} {{ "\\n\\n" + message.content + "\\n" }} {%- if loop.last or (messages[loop.index0 + 1].role != "tool") -%} {{ "<|im_end|>\\n" }} {%- endif -%} {%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") + "\\n" }}{%- endif -%}' + + dataset = dataset.add_column( + "chat_template_kwargs", [{"role_capital": bool(i % 2)} for i in range(len(dataset))] + ) + assert "chat_template_kwargs" in dataset.features + + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + ) + + # Assert trainer uses the same chat template as tokenizer + assert trainer.processing_class.chat_template == tokenizer.chat_template + + # Assert chat_template is applied + for i in range(2): + role = "SYSTEM" if i else "system" + system_prompt = ( + f"<|im_start|>{role}\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>" + ) + system_prompt_ids = trainer.processing_class(system_prompt)["input_ids"] + assert trainer.train_dataset[i]["input_ids"][: len(system_prompt_ids)] == system_prompt_ids + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_assistant_only(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, assistant_only_loss=True, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_completion_only(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train") + + # Initialize 
the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, completion_only_loss=True, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_completion_only_harmony(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/harmony", "prompt_completion", split="train") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, completion_only_loss=True, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-GptOssForCausalLM", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_assistant_only_and_completion_only(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train") + + # To test this case, we need to add user messages in the completion (they'll be masked in the loss) + def add_to_completion(example): + example["completion"].append(example["prompt"][0]) + example["completion"].append(example["completion"][0]) + return example + + dataset = dataset.map(add_to_completion) + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, assistant_only_loss=True, completion_only_loss=True, report_to="none" + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_assistant_only_iterable_dataset(self): + # Get the dataset + dataset = load_dataset( + "trl-internal-testing/zen", "conversational_language_modeling", split="train", streaming=True + ) + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, assistant_only_loss=True, max_steps=3, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: 
param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_with_set_chat_template_from_model(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, chat_template_path="Qwen/Qwen3-4B", report_to="none") + # trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default + trainer = SFTTrainer( + model="trl-internal-testing/tiny-GPTNeoXForCausalLM", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_with_set_chat_template_from_path(self, lazy_shared_datadir): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + chat_template_path=str(lazy_shared_datadir / "template.jinja"), + report_to="none", + ) + # trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default + trainer = SFTTrainer( + model="trl-internal-testing/tiny-GPTNeoXForCausalLM", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + # Check that the template saved in the output directory is the same as the one used for training + template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja" + assert template_path.exists(), f"Chat template not found at {template_path}" + + with open(template_path) as f: + template_content = f.read() + with open(training_args.chat_template_path) as f: + original_template_content = f.read() + assert template_content == original_template_content, "Chat template content does not match the original" + + def test_train_toolcall_data(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/toolcall", "language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters 
to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_with_eval(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + # Train the model + trainer.train() + + # Check that the eval loss is not None + assert trainer.state.log_history[0]["eval_loss"] is not None + + def test_train_with_multiple_eval_dataset(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset={"data1": dataset["test"], "data2": dataset["test"]}, + ) + # Train the model + trainer.train() + + # Check that the eval losses are not None + assert trainer.state.log_history[-3]["eval_data1_loss"] is not None + assert trainer.state.log_history[-2]["eval_data2_loss"] is not None + + def test_train_with_compute_metrics(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling") + + def dummy_compute_metrics(eval_pred): + return {"my_metric": 0.123} + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + eval_strategy="steps", + eval_steps=3, + report_to="none", + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + compute_metrics=dummy_compute_metrics, + ) + + # Train the model + trainer.train() + + # Check that the custom metric is logged + assert trainer.state.log_history[-2]["eval_my_metric"] == 0.123 + + # In practice, this test is the same as `test_train`, since gradient checkpointing is enabled by default in + # `SFTTrainer`. We keep it as a regression guard: if the default ever changes, we still explicitly test gradient + # checkpointing, which has caused issues in the past. 
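+    # For reference, the equivalent manual setup on a plain transformers model is roughly
+    # (a sketch for orientation, not used by these tests):
+    #     model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})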
+ def test_train_with_gradient_checkpointing(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + @pytest.mark.parametrize("use_reentrant", [True, False]) + def test_train_with_gradient_checkpointing_reentrant(self, use_reentrant): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + gradient_checkpointing=True, + gradient_checkpointing_kwargs={"use_reentrant": use_reentrant}, + report_to="none", + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_tag_added(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + # Initialize the trainer + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + train_dataset=dataset, + ) + + for tag in ["sft", "trl"]: + assert tag in trainer.model.model_tags + + @require_peft + def test_tag_added_peft(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + # Initialize the trainer + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + train_dataset=dataset, + peft_config=LoraConfig(), + ) + + for tag in ["sft", "trl"]: + assert tag in trainer.model.model_tags + + @pytest.mark.parametrize( + "model_id", + [ + "trl-internal-testing/tiny-Gemma3ForConditionalGeneration", + # "trl-internal-testing/tiny-Idefics2ForConditionalGeneration", high memory peak, skipped for now + # "trl-internal-testing/tiny-Idefics3ForConditionalGeneration", high memory peak, skipped for now + "trl-internal-testing/tiny-LlavaForConditionalGeneration", + "trl-internal-testing/tiny-LlavaNextForConditionalGeneration", + "trl-internal-testing/tiny-Qwen2VLForConditionalGeneration", + "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + # "trl-internal-testing/tiny-SmolVLMForConditionalGeneration", seems not to support bf16 properly + pytest.param( + 
"trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", + marks=[ + pytest.mark.skipif( + Version(transformers.__version__) < Version("4.57.0"), + reason="Qwen3-VL series were introduced in transformers-4.57.0", + ), + pytest.mark.xfail( + Version("5.0.0") <= Version(transformers.__version__) < Version("5.1.0"), + reason="Upstream transformers bug (transformers#43334) in 5.0.x; fixed in 5.1.0", + ), + ], + ), + ], + ) + @require_vision + def test_train_vlm(self, model_id): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + max_length=None, # for VLMs, truncating can remove image tokens, leading to errors + report_to="none", + ) + trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # For some reason, these params are not updated. This is probably not related to TRL, but to + # the model itself. We should investigate this further, but for now we just skip these params. + # fmt: off + if ( + model_id == "trl-internal-testing/tiny-Gemma3ForConditionalGeneration" and "model.vision_tower.vision_model.head" in n or + model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or + model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or + model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or + model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or + model_id == "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration" and "model.visual.deepstack_merger_list" in n + ): + # fmt: on + continue + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" + + @pytest.mark.parametrize( + "model_id", + [ + "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + ], + ) + @pytest.mark.xfail( + parse_version(transformers.__version__) < parse_version("4.57.0"), + reason="Mixing text-only and image+text examples is only supported in transformers >= 4.57.0", + strict=False, + ) + @require_vision + def test_train_vlm_multi_image(self, model_id): + # Get the dataset + dataset = load_dataset( + "trl-internal-testing/zen-multi-image", "conversational_prompt_completion", split="train" + ) + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + max_length=None, # for VLMs, truncating can remove image tokens, leading to errors + report_to="none", + ) + trainer = SFTTrainer( + model=model_id, + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + 
trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" + + @pytest.mark.parametrize( + "model_id", + [ + "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + # Special case for Gemma, as it uses token_type_ids, and we need to ensure they are properly in the collator: + "trl-internal-testing/tiny-Gemma3ForConditionalGeneration", + ], + ) + @require_vision + def test_train_vlm_prompt_completion(self, model_id): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_completion", split="train") + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + max_length=None, # for VLMs, truncating can remove image tokens, leading to errors + report_to="none", + ) + trainer = SFTTrainer( + model=model_id, + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" + + # Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing. + # To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs. 
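+    # (It is additionally skipped below because the google/gemma-3n-E2B-it checkpoint is gated on the Hub.)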
+
+    @pytest.mark.slow
+    @require_vision
+    @pytest.mark.skip(reason="Model google/gemma-3n-E2B-it is gated and requires HF token")
+    def test_train_vlm_gemma_3n(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
+            max_length=None,  # for VLMs, truncating can remove image tokens, leading to errors
+            per_device_train_batch_size=1,
+            model_init_kwargs={"dtype": "bfloat16"},
+            report_to="none",
+        )
+        trainer = SFTTrainer(model="google/gemma-3n-E2B-it", args=training_args, train_dataset=dataset)
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if "model.audio_tower" in n or "model.embed_audio" in n:
+                # The audio embedding parameters are not updated because this dataset contains no audio data
+                continue
+            assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
+
+    @pytest.mark.parametrize(
+        "model_id",
+        [
+            "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
+        ],
+    )
+    @pytest.mark.parametrize(
+        "dataset_config",
+        ["conversational_language_modeling", "conversational_prompt_completion", "standard_prompt_completion"],
+    )
+    @require_vision
+    def test_train_vlm_text_only_data(self, model_id, dataset_config):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", dataset_config, split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = SFTTrainer(
+            model=model_id,
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n.startswith("model.visual"):
+                torch.testing.assert_close(param, new_param, rtol=1e-12, atol=1e-12, msg=f"Param {n} is updated")
+            else:
+                assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
+
+    @require_peft
+    def test_prompt_tuning(self):
+        """Test that SFT works with Prompt Tuning."""
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = SFTTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            args=training_args,
+            train_dataset=dataset,
+            peft_config=PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8),
+        )
+
+        # Save initial parameters to check they change during training
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        # Check that training completed successfully
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+        assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if "base_model" in n:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            elif "prompt_encoder" in n:  # We expect the peft parameters to be different
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+            else:
+                raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}")
+
+    @require_peft
+    @require_bitsandbytes
+    def test_peft_with_quantization(self):
+        # Get the base model
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            dtype="float32",
+            quantization_config=quantization_config,
+        )
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        # Initialize the trainer with the already configured PeftModel
+        training_args = SFTConfig(output_dir=self.tmp_dir, learning_rate=0.1, report_to="none")
+        trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset, peft_config=LoraConfig())
+
+        # Save initial parameters to check they change during training
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        # Check that training completed successfully
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+        assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            # In bitsandbytes, bias parameters are automatically cast to the input dtype during the forward pass if
+            # their dtype doesn't match. This causes the module to change unexpectedly during the first forward pass of
+            # the training. To handle this, we cast these specific bias parameters to float32 before comparison.
+            # https://github.com/bitsandbytes-foundation/bitsandbytes/blob/45553f7392e524eacf400b132cfe01261f6477be/bitsandbytes/nn/modules.py#L518
+            # We still need to investigate why the compute dtype ends up being different for these parameters.
+            if n in [
+                "base_model.model.model.layers.1.self_attn.k_proj.bias",
+                "base_model.model.model.layers.1.self_attn.q_proj.base_layer.bias",
+                "base_model.model.model.layers.1.self_attn.v_proj.base_layer.bias",
+            ]:
+                param = param.float()
+
+            if "lora" not in n:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            else:  # We expect the peft parameters to be different
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    @require_peft
+    def test_prompt_tuning_peft_model(self):
+        """Test that SFT works with Prompt Tuning and a pre-converted PeftModel."""
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32")
+        model = get_peft_model(model, PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8))
+
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)
+
+        # Save initial parameters to check they change during training
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        # Check that training completed successfully
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+        assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if "base_model" in n:  # We expect the base model parameters to be the same
+                torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed")
+            elif "prompt_encoder" in n:  # We expect the peft parameters to be different
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+            else:
+                raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}")
+
+
+@pytest.mark.slow
+@require_torch_accelerator
+@require_peft
+class TestSFTTrainerSlow(TrlTestCase):
+    def setup_method(self):
+        self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
+        self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]")
+        self.max_length = 128
+        self.peft_config = LoraConfig(
+            lora_alpha=16,
+            lora_dropout=0.1,
+            r=8,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+
+    def teardown_method(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+        gc.collect()
+
+    @pytest.mark.parametrize("packing", [True, False])
+    @pytest.mark.parametrize(
+        "model_name",
+        [
+            "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
+            "trl-internal-testing/tiny-MistralForCausalLM-0.2",
+        ],
+    )
+    def test_sft_trainer_transformers_mp(self, model_name, packing):
+        """
+        Tests that passing a transformers model to `SFTTrainer` loads and runs the trainer as expected in mixed
+        precision. 
+ """ + training_args = SFTConfig( + output_dir=self.tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=10, + fp16=True, # this is sufficient to enable amp + packing=packing, + max_length=self.max_length, + ) + + model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32") + tokenizer = AutoTokenizer.from_pretrained(model_name) + + trainer = SFTTrainer( + model, + args=training_args, + processing_class=tokenizer, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + ) + + trainer.train() + + release_memory(model, trainer) + + @pytest.mark.parametrize("device_map", [{"": 0}, "auto"]) + @pytest.mark.parametrize( + "gradient_checkpointing_kwargs", [None, {"use_reentrant": False}, {"use_reentrant": True}] + ) + @pytest.mark.parametrize("packing", [True, False]) + @pytest.mark.parametrize( + "model_name", + [ + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", + ], + ) + @require_torch_multi_accelerator + def test_sft_trainer_transformers_mp_gc_device_map( + self, model_name, packing, gradient_checkpointing_kwargs, device_map + ): + """ + Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer as expected in mixed + precision + different scenarios of gradient_checkpointing (single, multi-gpu, etc). + """ + training_args = SFTConfig( + output_dir=self.tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=10, + packing=packing, + max_length=self.max_length, + fp16=True, # this is sufficient to enable amp + gradient_checkpointing=True, # default, here for clarity + gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, + ) + + model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32", device_map=device_map) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + trainer = SFTTrainer( + model, + args=training_args, + processing_class=tokenizer, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + ) + + trainer.train() + + release_memory(model, trainer) + + @pytest.mark.parametrize( + "gradient_checkpointing_kwargs", [None, {"use_reentrant": False}, {"use_reentrant": True}] + ) + @pytest.mark.parametrize("packing", [True, False]) + @pytest.mark.parametrize( + "model_name", + [ + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", + ], + ) + @require_peft + @require_bitsandbytes + def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gradient_checkpointing_kwargs): + """ + Simply tests if passing a transformers model + PEFT + bnb to `SFTTrainer` loads and runs the trainer as + expected in mixed precision + different scenarios of gradient_checkpointing. 
+ """ + training_args = SFTConfig( + output_dir=self.tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=10, + packing=packing, + max_length=self.max_length, + gradient_checkpointing=True, # default, here for clarity + gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, + ) + + quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16) + + model = AutoModelForCausalLM.from_pretrained( + model_name, dtype="float32", quantization_config=quantization_config + ) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + trainer = SFTTrainer( + model, + args=training_args, + processing_class=tokenizer, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + peft_config=self.peft_config, + ) + + assert isinstance(trainer.model, PeftModel) + + trainer.train() + + release_memory(model, trainer) + + @pytest.mark.parametrize("packing", [True, False]) + @pytest.mark.parametrize( + "model_name", + [ + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", + ], + ) + @require_peft + @require_bitsandbytes + def test_sft_trainer_with_chat_format_qlora(self, model_name, packing): + """ + Simply tests if using setup_chat_format with a transformers model + peft + bnb config to `SFTTrainer` loads and + runs the trainer as expected. + """ + train_dataset = load_dataset("trl-internal-testing/dolly-chatml-sft", split="train") + + training_args = SFTConfig( + packing=packing, + max_length=self.max_length, + output_dir=self.tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=10, + ) + + quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16) + + model = AutoModelForCausalLM.from_pretrained( + model_name, dtype="float32", quantization_config=quantization_config + ) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + trainer = SFTTrainer( + model, + args=training_args, + processing_class=tokenizer, + train_dataset=train_dataset, + peft_config=self.peft_config, + ) + + assert isinstance(trainer.model, PeftModel) + + trainer.train() + + release_memory(model, trainer) + + @pytest.mark.parametrize("packing", [True, False]) + @pytest.mark.parametrize( + "model_name", + [ + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", + ], + ) + @require_liger_kernel + def test_sft_trainer_with_liger(self, model_name, packing): + """ + Tests if passing use_liger=True to SFTConfig loads and runs the trainer with AutoLigerKernelForCausalLM as + expected. 
+ """ + import importlib + + def cleanup_liger_patches(trainer): + """Clean up liger_kernel patches by reloading the model's specific module""" + try: + # Get the specific module that was used by the trainer's model + module_path = trainer.model.__module__ + reload_module = importlib.import_module(module_path) + importlib.reload(reload_module) + except Exception: + pass # Continue if reload fails + + training_args = SFTConfig( + output_dir=self.tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=2, + packing=packing, + max_length=self.max_length, + use_liger_kernel=True, + ) + + trainer = SFTTrainer( + model_name, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + ) + + # Ensure cleanup of liger patches after the test + try: + trainer.train() + release_memory(trainer.model, trainer) + finally: + cleanup_liger_patches(trainer) + + @pytest.mark.parametrize("packing", [True, False]) + @pytest.mark.parametrize( + "model_name", + [ + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", + ], + ) + @require_torch_accelerator + def test_train_offloading(self, model_name, packing): + """Test that activation offloading works with SFTTrainer.""" + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + activation_offloading=True, + report_to="none", + per_device_train_batch_size=2, + max_steps=2, + packing=packing, + max_length=self.max_length, + ) + trainer = SFTTrainer( + model=model_name, args=training_args, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + release_memory(trainer.model, trainer) diff --git a/ICL/RL/trl_source/tests/test_vllm_client_server.py b/ICL/RL/trl_source/tests/test_vllm_client_server.py new file mode 100644 index 0000000000000000000000000000000000000000..e3464dae6d69e438db81c1340aa959dd1cd5fa98 --- /dev/null +++ b/ICL/RL/trl_source/tests/test_vllm_client_server.py @@ -0,0 +1,616 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import subprocess + +import pytest +from transformers import AutoModelForCausalLM +from transformers.testing_utils import torch_device + +from trl.generation.vllm_client import VLLMClient +from trl.import_utils import is_vllm_available +from trl.scripts.vllm_serve import chunk_list + +from .testing_utils import ( + TrlTestCase, + kill_process, + require_3_accelerators, + require_torch_multi_accelerator, + require_vllm, +) + + +if is_vllm_available(): + from vllm import LLM, SamplingParams + + +class TestChunkList(TrlTestCase): + def test_even_split(self): + assert chunk_list([1, 2, 3, 4, 5, 6], 2) == [[1, 2, 3], [4, 5, 6]] + + def test_uneven_split(self): + assert chunk_list([1, 2, 3, 4, 5, 6], 4) == [[1, 2], [3, 4], [5], [6]] + + def test_more_chunks_than_elements(self): + assert chunk_list([1, 2, 3, 4, 5, 6], 8) == [[1], [2], [3], [4], [5], [6], [], []] + + def test_n_equals_len(self): + assert chunk_list([1, 2, 3], 3) == [[1], [2], [3]] + + def test_n_is_1(self): + assert chunk_list([1, 2, 3], 1) == [[1, 2, 3]] + + def test_single_element_list(self): + assert chunk_list([42], 2) == [[42], []] + + def test_any_dtype(self): + assert chunk_list([1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]], 2) == [ + [1, "two", 3.0], + [{"four": 4}, ["f", "i", "v", "e"]], + ] + + +@pytest.mark.slow +@require_torch_multi_accelerator +@require_vllm +class TestVLLMClientServer(TrlTestCase): + model_id = "Qwen/Qwen2.5-1.5B" + + @classmethod + def setup_class(cls): + # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1" + env = os.environ.copy() + VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" + env[VISIBLE_DEVICES] = "1" # Restrict to accelerator 1 + + # Start the server process + cls.server_process = subprocess.Popen( + ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env + ) + + # Initialize the client + cls.client = VLLMClient(connection_timeout=240, host="localhost") + cls.client.init_communicator() + + def test_generate(self): + prompts = ["Hello, AI!", "Tell me a joke"] + outputs = self.client.generate(prompts) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + # Check that the outputs are lists + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of prompts + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) + + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + + def test_chat(self): + messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]] + outputs = self.client.chat(messages) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + # Check that the outputs are lists + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of messages + assert len(prompt_ids) == len(messages) + assert len(completion_ids) == len(messages) + + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + + def test_generate_with_params(self): + prompts = ["Hello, AI!", 
"Tell me a joke"] + completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ + "completion_ids" + ] + + # Check that the output is a list + assert isinstance(completion_ids, list) + + # Check that the number of generated sequences is 2 times the number of prompts + assert len(completion_ids) == 2 * len(prompts) + + # Check that the generated sequences are lists of integers + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + + # Check that the length of the generated sequences is less than or equal to 32 + for seq in completion_ids: + assert len(seq) <= 32 + + def test_update_model_params(self): + model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) + self.client.update_model_params(model) + + def test_reset_prefix_cache(self): + # Test resetting the prefix cache + self.client.reset_prefix_cache() + + def test_chat_completions_endpoint(self): + data = self.client.chat_completions( + messages=[{"role": "user", "content": "Say hello"}], + max_tokens=32, + ) + + assert "id" in data + assert "choices" in data + assert "usage" in data + assert len(data["choices"]) > 0 + assert data["choices"][0]["message"]["role"] == "assistant" + assert data["choices"][0]["finish_reason"] in ["stop", "length", "tool_calls"] + + def test_chat_completions_with_tools(self): + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information for a location", + "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}, + }, + } + ] + data = self.client.chat_completions( + messages=[{"role": "user", "content": "What's the weather in San Francisco?"}], + tools=tools, + max_tokens=100, + ) + + assert "choices" in data + assert len(data["choices"]) > 0 + assert "message" in data["choices"][0] + + def test_chat_completions_with_params(self): + data = self.client.chat_completions( + messages=[{"role": "user", "content": "Tell me a joke"}], + n=2, + temperature=0.8, + top_p=0.9, + max_tokens=32, + ) + + assert len(data["choices"]) == 2 + + for i, choice in enumerate(data["choices"]): + assert choice["index"] == i, f"Expected choice at position {i} to have index {i}, got {choice['index']}" + assert "message" in choice + assert choice["message"]["role"] == "assistant" + + def test_tokenize_endpoint(self): + data = self.client.tokenize(messages=[{"role": "user", "content": "Hello, how are you?"}]) + + assert "tokens" in data + assert "model" in data + assert isinstance(data["tokens"], list) + assert len(data["tokens"]) > 0 + assert all(isinstance(tok, int) for tok in data["tokens"]) + + @pytest.mark.xfail(reason="Importing `bitsandbytes` causes issues, see vllm-project/vllm#32793") + def test_logprobs_match_with_non_default_sampling(self): + prompts = ["Hello, AI!", "Tell me a joke"] + # Use non-default sampling parameters (especially temperature) to ensure vLLM applies logprob processing. With + # default sampling, raw and processed logprobs are identical, so mismatches would not be detected. 
+ temperature = 0.7 + repetition_penalty = 1.05 + top_p = 0.9 + max_tokens = 8 + seed = 1234 + + server_outputs = self.client.generate( + prompts, + temperature=temperature, + repetition_penalty=repetition_penalty, + top_p=top_p, + max_tokens=max_tokens, + generation_kwargs={"seed": seed}, + ) + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + llm = LLM( + model=self.model_id, + tensor_parallel_size=1, + gpu_memory_utilization=0.2, + max_model_len=128, + logprobs_mode="processed_logprobs", + ) + + sampling_params = SamplingParams( + temperature=temperature, + repetition_penalty=repetition_penalty, + top_p=top_p, + max_tokens=max_tokens, + logprobs=0, # this is what's used in practice to get the logprobs of generated tokens + seed=seed, + ) + colocate_outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=False) + colocate_prompt_ids = [output.prompt_token_ids for output in colocate_outputs] + colocate_completion_ids = [ + list(output.token_ids) for outputs in colocate_outputs for output in outputs.outputs + ] + colocate_logprobs = [ + [next(iter(logprob.values())).logprob for logprob in output.logprobs] + for outputs in colocate_outputs + for output in outputs.outputs + ] + + assert server_outputs["prompt_ids"] == colocate_prompt_ids + assert server_outputs["completion_ids"] == colocate_completion_ids + server_logprobs = server_outputs["logprobs"] + assert len(server_logprobs) == len(colocate_logprobs) + for server_seq, colocate_seq in zip(server_logprobs, colocate_logprobs, strict=True): + assert len(server_seq) == len(colocate_seq) + assert server_seq == pytest.approx(colocate_seq, rel=1e-6, abs=1e-6) + + @classmethod + def teardown_class(cls): + # Close the client + cls.client.close_communicator() + + # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to + # kill the server process and its children explicitly. + kill_process(cls.server_process) + + +# Same as above but using base_url to instantiate the client. 
+@pytest.mark.slow +@require_torch_multi_accelerator +@require_vllm +class TestVLLMClientServerBaseURL(TrlTestCase): + model_id = "Qwen/Qwen2.5-1.5B" + + @classmethod + def setup_class(cls): + # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1" + env = os.environ.copy() + VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" + env[VISIBLE_DEVICES] = "1" # Restrict to accelerator 1 + + # Start the server process + cls.server_process = subprocess.Popen( + ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env + ) + + # Initialize the client + cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=240) + cls.client.init_communicator() + + def test_generate(self): + prompts = ["Hello, AI!", "Tell me a joke"] + outputs = self.client.generate(prompts) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + # Check that the outputs are lists + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of prompts + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) + + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + + def test_chat(self): + messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]] + outputs = self.client.chat(messages) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + # Check that the outputs are lists + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of messages + assert len(prompt_ids) == len(messages) + assert len(completion_ids) == len(messages) + + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + + def test_generate_with_params(self): + prompts = ["Hello, AI!", "Tell me a joke"] + completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ + "completion_ids" + ] + + # Check that the output is a list + assert isinstance(completion_ids, list) + + # Check that the number of generated sequences is 2 times the number of prompts + assert len(completion_ids) == 2 * len(prompts) + + # Check that the generated sequences are lists of integers + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + + # Check that the length of the generated sequences is less than or equal to 32 + for seq in completion_ids: + assert len(seq) <= 32 + + def test_update_model_params(self): + model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) + self.client.update_model_params(model) + + def test_reset_prefix_cache(self): + # Test resetting the prefix cache + self.client.reset_prefix_cache() + + @classmethod + def teardown_class(cls): + # Close the client + cls.client.close_communicator() + + # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to + # kill the server process and its children explicitly. 
+ kill_process(cls.server_process) + + +@pytest.mark.slow +@require_3_accelerators +@require_vllm +class TestVLLMClientServerTP(TrlTestCase): + model_id = "Qwen/Qwen2.5-1.5B" + + @classmethod + def setup_class(cls): + # We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2" + env = os.environ.copy() + VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" + env[VISIBLE_DEVICES] = "1,2" # Restrict to accelerator 1 and 2 + + # Start the server process + cls.server_process = subprocess.Popen( + ["trl", "vllm-serve", "--model", cls.model_id, "--tensor_parallel_size", "2"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) + + # Initialize the client + cls.client = VLLMClient(connection_timeout=240, host="localhost") + cls.client.init_communicator() + + def test_generate(self): + prompts = ["Hello, AI!", "Tell me a joke"] + outputs = self.client.generate(prompts) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + # Check that the outputs are lists + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of prompts + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) + + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + + def test_chat(self): + messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]] + outputs = self.client.chat(messages) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + # Check that the outputs are lists + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of messages + assert len(prompt_ids) == len(messages) + assert len(completion_ids) == len(messages) + + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + + def test_update_model_params(self): + model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) + self.client.update_model_params(model) + + def test_reset_prefix_cache(self): + # Test resetting the prefix cache + self.client.reset_prefix_cache() + + @classmethod + def teardown_class(cls): + # Close the client + cls.client.close_communicator() + + # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to + # kill the server process and its children explicitly. 
+ kill_process(cls.server_process) + + +@pytest.mark.slow +@require_3_accelerators +@require_vllm +class TestVLLMClientServerDP(TrlTestCase): + model_id = "Qwen/Qwen2.5-1.5B" + + @classmethod + def setup_class(cls): + # We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2" + env = os.environ.copy() + VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" + env[VISIBLE_DEVICES] = "1,2" # Restrict to accelerator 1 and 2 + + # Start the server process + cls.server_process = subprocess.Popen( + ["trl", "vllm-serve", "--model", cls.model_id, "--data_parallel_size", "2"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) + + # Initialize the client + cls.client = VLLMClient(connection_timeout=240, host="localhost") + cls.client.init_communicator() + + def test_generate(self): + prompts = ["Hello, AI!", "Tell me a joke"] + outputs = self.client.generate(prompts) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + # Check that the outputs are lists + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of prompts + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) + + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + + def test_chat(self): + messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]] + outputs = self.client.chat(messages) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + # Check that the outputs are lists + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of messages + assert len(prompt_ids) == len(messages) + assert len(completion_ids) == len(messages) + + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + + def test_update_model_params(self): + model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) + self.client.update_model_params(model) + + def test_reset_prefix_cache(self): + # Test resetting the prefix cache + self.client.reset_prefix_cache() + + @classmethod + def teardown_class(cls): + # Close the client + cls.client.close_communicator() + + # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to + # kill the server process and its children explicitly. 
+        kill_process(cls.server_process)
+
+
+@pytest.mark.slow
+@require_torch_multi_accelerator
+@require_vllm
+class TestVLLMClientServerDeviceParameter(TrlTestCase):
+    """Test the device parameter functionality in init_communicator."""
+
+    model_id = "Qwen/Qwen2.5-1.5B"
+
+    @classmethod
+    def setup_class(cls):
+        # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
+        env = os.environ.copy()
+        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
+        env[VISIBLE_DEVICES] = "1"  # Restrict to accelerator 1
+
+        # Start the server process
+        cls.server_process = subprocess.Popen(
+            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
+        )
+
+    def test_init_communicator_with_device_int(self):
+        """Test init_communicator with integer device parameter."""
+        client = VLLMClient(connection_timeout=240, host="localhost")
+        client.init_communicator(device=0)  # Explicitly specify device 0
+
+        # Test basic functionality
+        prompts = ["Hello, AI!"]
+        outputs = client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]
+        assert isinstance(prompt_ids, list)
+        assert len(prompt_ids) == len(prompts)
+        assert isinstance(completion_ids, list)
+        assert len(completion_ids) == len(prompts)
+
+        client.close_communicator()
+
+    def test_init_communicator_with_device_string(self):
+        """Test init_communicator with string device parameter."""
+        client = VLLMClient(connection_timeout=240, host="localhost")
+        client.init_communicator(device=torch_device)  # Explicitly specify device as a string (e.g. "cuda")
+
+        # Test basic functionality
+        prompts = ["Hello, AI!"]
+        outputs = client.generate(prompts)["completion_ids"]
+        assert isinstance(outputs, list)
+        assert len(outputs) == len(prompts)
+
+        client.close_communicator()
+
+    def test_init_communicator_with_torch_device(self):
+        """Test init_communicator with torch.device object."""
+        import torch
+
+        client = VLLMClient(connection_timeout=240, host="localhost")
+        device = torch.device(0)
+        client.init_communicator(device=device)  # Explicitly specify torch.device object
+
+        # Test basic functionality
+        prompts = ["Hello, AI!"]
+        outputs = client.generate(prompts)["completion_ids"]
+        assert isinstance(outputs, list)
+        assert len(outputs) == len(prompts)
+
+        client.close_communicator()
+
+    @classmethod
+    def teardown_class(cls):
+        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
+        # kill the server process and its children explicitly.
+        kill_process(cls.server_process)
diff --git a/ICL/RL/trl_source/tests/testing_utils.py b/ICL/RL/trl_source/tests/testing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4522cd7f0c68f1a8118ae4cfff4b627719976c18
--- /dev/null
+++ b/ICL/RL/trl_source/tests/testing_utils.py
@@ -0,0 +1,146 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import functools +import signal +import warnings +from collections.abc import Callable + +import psutil +import pytest +import torch +from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available +from transformers.testing_utils import backend_device_count, torch_device +from transformers.utils import ( + is_kernels_available, + is_peft_available, + is_rich_available, + is_torch_available, + is_vision_available, +) + +from trl.import_utils import ( + is_jmespath_available, + is_joblib_available, + is_liger_kernel_available, + is_llm_blender_available, + is_math_verify_available, + is_mergekit_available, + is_vllm_available, +) + + +require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes") +require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml") +require_jmespath = pytest.mark.skipif(not is_jmespath_available(), reason="test requires jmespath") +require_kernels = pytest.mark.skipif(not is_kernels_available(), reason="test requires kernels") +require_liger_kernel = pytest.mark.skipif(not is_liger_kernel_available(), reason="test requires liger-kernel") +require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender") +require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify") +require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit") +require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft") +require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich") +require_sklearn = pytest.mark.skipif( + not (is_sklearn_available() and is_joblib_available()), reason="test requires sklearn" +) +require_torch_accelerator = pytest.mark.skipif( + torch_device is None or torch_device == "cpu", reason="test requires accelerator" +) +require_torch_multi_accelerator = pytest.mark.skipif( + not is_torch_available() or backend_device_count(torch_device) <= 1, reason="test requires multiple accelerators" +) +require_vision = pytest.mark.skipif(not is_vision_available(), reason="test requires vision") +require_vllm = pytest.mark.skipif(not is_vllm_available(), reason="test requires vllm") +require_wandb = pytest.mark.skipif(not is_wandb_available(), reason="test requires wandb") +require_no_wandb = pytest.mark.skipif(is_wandb_available(), reason="test requires no wandb") +require_3_accelerators = pytest.mark.skipif( + not (getattr(torch, torch_device, torch.cuda).device_count() >= 3), + reason=f"test requires at least 3 {torch_device}s", +) + + +def is_bitsandbytes_multi_backend_available() -> bool: + if is_bitsandbytes_available(): + import bitsandbytes as bnb + + return "multi_backend" in getattr(bnb, "features", set()) + return False + + +# Function ported from transformers.testing_utils before transformers#41283 +require_torch_gpu_if_bnb_not_multi_backend_enabled = pytest.mark.skipif( + not is_bitsandbytes_multi_backend_available() and not torch_device == "cuda", + reason="test requires bitsandbytes multi-backend enabled or 'cuda' torch device", +) + + +def is_ampere_or_newer(device_index=0): + if not torch.cuda.is_available(): + return False + + major, minor = torch.cuda.get_device_capability(device_index) + # Ampere starts at compute capability 8.0 (e.g., A100 = 8.0, RTX 30xx = 8.6) + return (major, minor) >= (8, 0) + + +require_ampere_or_newer = pytest.mark.skipif(not 
is_ampere_or_newer(), reason="test requires Ampere or newer GPU")
+
+
+class TrlTestCase:
+    @pytest.fixture(autouse=True)
+    def set_tmp_dir(self, tmp_path):
+        self.tmp_dir = str(tmp_path)
+
+
+def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> Callable:
+    """
+    Decorator to ignore warnings with a specific message and/or category.
+
+    Args:
+        message (`str`, *optional*):
+            Regex pattern for the warning message to ignore. If `None`, all messages are ignored.
+        category (`type[Warning]`, *optional*, defaults to `Warning`):
+            Warning class to ignore. Defaults to `Warning`, which ignores all warnings.
+    """
+
+    def decorator(test_func):
+        @functools.wraps(test_func)
+        def wrapper(*args, **kwargs):
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", message=message, category=category)
+                return test_func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def kill_process(process):
+    parent = psutil.Process(process.pid)
+    children = parent.children(recursive=True)
+    for child in children:
+        try:
+            child.send_signal(signal.SIGTERM)
+            child.wait(timeout=5)
+        except psutil.TimeoutExpired:
+            child.kill()
+        except psutil.NoSuchProcess:
+            pass
+    try:
+        process.terminate()
+        process.wait(timeout=5)
+    except psutil.TimeoutExpired:
+        process.kill()
+    except psutil.NoSuchProcess:
+        pass
diff --git a/ICL/RL/visualize_rewards.py b/ICL/RL/visualize_rewards.py
new file mode 100644
index 0000000000000000000000000000000000000000..20fefadad339c2ef95a454673ff1f970a969f03c
--- /dev/null
+++ b/ICL/RL/visualize_rewards.py
@@ -0,0 +1,203 @@
+"""
+Utility script for debugging and visualization.
+"""
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def visualize_reward_landscape():
+    """
+    Visualize how rewards change with different combinations of
+    correctness, retrieval steps, and SigLIP scores.
+    """
+    from reward_functions import compute_reward, scale_siglip_score
+
+    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
+    fig.suptitle('GRPO Reward Landscape', fontsize=16)
+
+    # 1. Reward vs. Number of Steps (when correct)
+    ax = axes[0, 0]
+    steps = [0, 1, 2, 3]
+    siglip_scores_variants = [
+        [0.97] * 3,  # Perfect retrieval
+        [0.65] * 3,  # Medium retrieval
+        [0.33] * 3,  # Poor retrieval
+    ]
+    labels = ['Perfect (0.97)', 'Medium (0.65)', 'Poor (0.33)']
+
+    for scores, label in zip(siglip_scores_variants, labels):
+        rewards = []
+        for n_steps in steps:
+            # NOTE: the "<ret>"/"<answer>" markers below are assumed to match the
+            # trajectory format expected by reward_functions.
+            if n_steps == 0:
+                # Direct answer that matches the ground truth, so R_outcome is correct
+                traj = "<answer> answer"
+                result = compute_reward(traj, "answer", [], timeout=False)
+            else:
+                traj = " ".join(["<ret> desc"] * n_steps) + " <answer> answer"
+                result = compute_reward(traj, "answer", scores[:n_steps], timeout=False)
+
+            rewards.append(result["r_total"])
+
+        ax.plot(steps, rewards, marker='o', label=label, linewidth=2)
+
+    ax.set_xlabel('Number of Retrieval Steps')
+    ax.set_ylabel('Total Reward')
+    ax.set_title('Reward vs. Steps (Correct Answer)')
+    ax.legend()
+    ax.grid(True, alpha=0.3)
+
+    # 2. Reward vs. SigLIP Score (1-shot, correct)
+    ax = axes[0, 1]
+    siglip_range = np.linspace(0.3, 1.0, 50)
+    rewards = []
+
+    for score in siglip_range:
+        # Analytic total for a correct 1-shot trajectory:
+        # R_outcome (+1.0) + R_penalty (-0.1) + R_rel (scaled SigLIP score)
+        r_total = 1.0 - 0.1 + scale_siglip_score(score)
+        rewards.append(r_total)
+
+    ax.plot(siglip_range, rewards, linewidth=2, color='green')
+    ax.axvline(x=0.33, color='red', linestyle='--', alpha=0.5, label='Min threshold')
+    ax.axvline(x=0.97, color='blue', linestyle='--', alpha=0.5, label='Max threshold')
+    ax.set_xlabel('SigLIP Similarity Score')
+    ax.set_ylabel('Total Reward')
+    ax.set_title('Reward vs. SigLIP Score (1-shot, Correct)')
+    ax.legend()
+    ax.grid(True, alpha=0.3)
+
+    # 3. Reward Breakdown by Scenario
+    ax = axes[1, 0]
+    scenarios = [
+        'Direct\nWrong',
+        '1-shot\nPerfect',
+        '1-shot\nMedium',
+        '1-shot\nBad',
+        '3-shot\nPerfect',
+        'Direct\nCorrect',
+        'Timeout'
+    ]
+    r_outcome = [-1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0]
+    r_penalty = [0.0, -0.1, -0.1, -0.1, -0.5, 0.0, -2.0]
+    r_rel = [0.0, 1.0, 0.5, 0.0, 1.0, 0.0, 0.0]
+
+    x = np.arange(len(scenarios))
+    width = 0.25
+
+    ax.bar(x - width, r_outcome, width, label='R_outcome', color='blue')
+    ax.bar(x, r_penalty, width, label='R_penalty', color='red')
+    ax.bar(x + width, r_rel, width, label='R_rel', color='green')
+
+    ax.set_xlabel('Scenario')
+    ax.set_ylabel('Reward Component')
+    ax.set_title('Reward Breakdown by Scenario')
+    ax.set_xticks(x)
+    ax.set_xticklabels(scenarios, fontsize=8)
+    ax.legend()
+    ax.grid(True, alpha=0.3, axis='y')
+    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
+
+    # 4. Total Reward Comparison
+    ax = axes[1, 1]
+    total_rewards = [sum(x) for x in zip(r_outcome, r_penalty, r_rel)]
+    colors = ['red' if r < 0 else 'green' for r in total_rewards]
+
+    bars = ax.bar(scenarios, total_rewards, color=colors, alpha=0.7)
+    ax.set_xlabel('Scenario')
+    ax.set_ylabel('Total Reward')
+    ax.set_title('Total Reward by Scenario')
+    ax.axhline(y=0, color='black', linestyle='-', linewidth=1)
+    ax.grid(True, alpha=0.3, axis='y')
+
+    # Add value labels on bars
+    for bar, reward in zip(bars, total_rewards):
+        height = bar.get_height()
+        ax.text(bar.get_x() + bar.get_width()/2., height,
+                f'{reward:.1f}',
+                ha='center', va='bottom' if height > 0 else 'top',
+                fontsize=9, fontweight='bold')
+
+    plt.tight_layout()
+    plt.savefig('/workspace/xiaobin/RL/reward_landscape.png', dpi=150, bbox_inches='tight')
+    print("Saved reward landscape visualization to reward_landscape.png")
+    plt.close()
+
+
+def print_reward_table():
+    """Print the reward table from the plan."""
+    print("\n" + "="*80)
+    print("REWARD TABLE (from plan.md)")
+    print("="*80)
+    print()
+    print("| Sample | Behavior              | R_outcome | R_penalty | R_rel | Total |")
+    print("|--------|-----------------------|-----------|-----------|-------|-------|")
+    print("| 1      | Direct wrong          | -1.0      | 0.0       | 0.0   | -1.0  |")
+    print("| 2      | 1-shot perfect        | +1.0      | -0.1      | +1.0  | +1.9  |")
+    print("| 3      | 1-shot medium         | +1.0      | -0.1      | +0.5  | +1.4  |")
+    print("| 4      | 1-shot bad            | +1.0      | -0.1      | +0.0  | +0.9  |")
+    print("| 5      | 3-shot perfect        | +1.0      | -0.5      | +1.0  | +1.5  |")
+    print("| 6      | Direct correct        | +1.0      | 0.0       | 0.0   | +1.0  |")
+    print("| 7      | Timeout (death loop)  | -1.0      | -2.0      | 0.0   | -3.0  |")
+    print()
+    print("="*80)
+    print()
+
+
+def analyze_grpo_gradient():
+    """Analyze GRPO gradient direction based on reward differences."""
+    print("\n" + "="*80)
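+    # Note: the Δ values printed below are raw total-reward gaps between sample
+    # types. GRPO itself trains on group-relative advantages (rewards are
+    # mean-centered within each group, and std-normalized when
+    # norm_adv_by_std_in_grpo is set), so what matters here is the sign and
+    # ordering of these gaps, which determine the gradient direction.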
+    print("GRPO GRADIENT ANALYSIS")
+    print("="*80)
+    print()
+    print("Sample 2 (1-shot perfect, reward=1.9) is the winner!")
+    print()
+    print("Comparisons:")
+    print("  • Sample 2 (1.9) vs Sample 1 (-1.0): Δ = +2.9")
+    print("    → Learn to retrieve when needed")
+    print()
+    print("  • Sample 2 (1.9) vs Sample 5 (1.5): Δ = +0.4")
+    print("    → Learn to use fewer steps (1-shot > 3-shot)")
+    print()
+    print("  • Sample 2 (1.9) vs Sample 4 (0.9): Δ = +1.0")
+    print("    → Learn to write accurate descriptions (high R_rel)")
+    print()
+    print("  • Sample 2 (1.9) vs Sample 6 (1.0): Δ = +0.9")
+    print("    → For retrieval-needed questions, retrieve > guess")
+    print()
+    print("  • Sample 2 (1.9) vs Sample 7 (-3.0): Δ = +4.9")
+    print("    → Avoid infinite loops at all costs")
+    print()
+    print("="*80)
+    print()
+
+
+def main():
+    """Run all visualization and analysis."""
+    print("\n" + "="*80)
+    print("GRPO REWARD VISUALIZATION & ANALYSIS")
+    print("="*80)
+
+    print_reward_table()
+    analyze_grpo_gradient()
+
+    print("\nGenerating reward landscape visualization...")
+    try:
+        visualize_reward_landscape()
+        print("✓ Visualization complete!")
+    except Exception as e:
+        print(f"✗ Visualization failed: {e}")
+        print("  (matplotlib may not be available in this environment)")
+
+    print("\n" + "="*80)
+    print("ANALYSIS COMPLETE")
+    print("="*80 + "\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ICL/RL_DAPO/README.md b/ICL/RL_DAPO/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..33daf067c1dc44e113b4f9448094e3674f9ebe98
--- /dev/null
+++ b/ICL/RL_DAPO/README.md
@@ -0,0 +1,154 @@
+# RL_DAPO: DAPO Training for Retrieval-Augmented VQA
+
+Migration from GRPO (`xiaobin/RL`) to DAPO, built on the verl framework.
+
+## DAPO vs GRPO: Key Differences
+
+| Feature | GRPO (old) | DAPO (new) |
+|------|-----------|-----------|
+| Framework | TRL GRPOTrainer | verl RayDAPOTrainer |
+| Clipping | Symmetric ε=0.2 | Asymmetric ε_low=0.2, ε_high=0.28 |
+| KL penalty | β=0.04 | No KL (use_kl=False) |
+| Loss aggregation | seq-level | token-level (token-mean) |
+| Dynamic sampling | None | Group filtering (drops all-correct/all-wrong groups) |
+| Overlong penalty | None | Overlong buffer (linear penalty) |
+| Rollout | HF generate | vLLM (remote NorthServe service) |
+| Distributed | accelerate | Ray + FSDP |
+| Reward | GRPO reward logic | Unchanged, fully retained |
+| Data | rl_train.jsonl | Converted to parquet (same distribution) |
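+
+For intuition, the clipping and loss-aggregation rows amount to a loss of roughly the
+shape below. This is a minimal illustrative sketch, not the actual verl kernel; the
+tensor names and the way the mask is applied are assumptions:
+
+```python
+import torch
+
+
+def dapo_policy_loss(log_prob, old_log_prob, advantages, response_mask,
+                     clip_ratio_low=0.2, clip_ratio_high=0.28):
+    """Token-mean PPO-style loss with decoupled clip ratios (sketch)."""
+    ratio = torch.exp(log_prob - old_log_prob)
+    # Asymmetric clipping: the upper bound (1 + eps_high) is looser than the
+    # lower bound (1 - eps_low), which keeps low-probability tokens trainable.
+    clipped = torch.clamp(ratio, 1.0 - clip_ratio_low, 1.0 + clip_ratio_high)
+    per_token_loss = -torch.min(ratio * advantages, clipped * advantages)
+    # token-mean: average over all response tokens in the batch at once,
+    # instead of averaging per sequence first (the seq-level mode GRPO used).
+    return (per_token_loss * response_mask).sum() / response_mask.sum()
+```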
+
+## Project Structure
+
+```
+xiaobin/RL_DAPO/
+├── __init__.py
+├── main_dapo.py          # Entry point: hydra + ray launcher
+├── dapo_ray_trainer.py   # DAPO trainer (inherits RayPPOTrainer)
+├── compute_score.py      # Bridge function for verl's reward manager
+├── reward_functions.py   # Core reward logic (kept from GRPO)
+├── environment.py        # SigLIP retrieval environment (kept from GRPO)
+├── prepare_data.py       # GRPO jsonl -> verl parquet conversion
+├── run_dapo.sh           # Training launch script
+├── requirements.txt
+├── config/
+│   └── dapo_trainer.yaml # DAPO training config
+└── data/                 # Converted parquet data (generated by running prepare_data.py)
+```
+
+## Environment Setup
+
+All packages are installed from the Tsinghua mirror. vLLM does not need to be installed locally; it runs as a remote NorthServe service.
+
+### Install Order
+
+```bash
+# 1) PyTorch (CUDA 12)
+pip install torch==2.4.0 torchvision torchaudio \
+    --index-url https://download.pytorch.org/whl/cu121
+
+# 2) Flash Attention (local wheel)
+pip install /workspace/flash_attn-2.8.3+cu12torch2.4cxx11abiTRUE-cp311-cp311-linux_x86_64.whl
+
+# 3) verl
+pip install verl -i https://pypi.tuna.tsinghua.edu.cn/simple
+
+# 4) Remaining dependencies
+pip install -r /workspace/xiaobin/RL_DAPO/requirements.txt \
+    -i https://pypi.tuna.tsinghua.edu.cn/simple
+```
+
+### vLLM Service (NorthServe)
+
+No local vLLM install is required. Launch a remote vLLM service via NorthServe:
+
+```bash
+# Launch the Qwen3-VL model service (adjust model path and settings as needed)
+HOME=/root /workspace/NorthServe/northserve launch \
+  --model-name qwen3vl-dapo \
+  --served-model-name Qwen3-VL-SigLIP-VQA \
+  --namespace bg-agentic-coding \
+  --model-path /i_workspace/xiaobin/SFT_model/hf_qwen3vl_siglip_vqa_iter_0000881 \
+  --volumes "i-xinsiyang-y4zy0sik0a:/i_workspace" \
+  --replicas 1 \
+  --gpus-per-pod 8 \
+  --pods-per-job 1 \
+  --profile generation \
+  --backend vllm \
+  --priority-class-name higher-priority-job \
+  --extra-cmds "--trust-remote-code --max-model-len 4096 --max-num-seqs 64" \
+  -y
+```
+
+Verify the service:
+```bash
+curl -s http://10.51.6.110/v1/models | jq '.data[].id'
+```
+
+Stop the service:
+```bash
+HOME=/root /workspace/NorthServe/northserve stop \
+  --namespace bg-agentic-coding --model-name qwen3vl-dapo -y
+```
+
+## Data Preparation
+
+The data reuses GRPO's `rl_train.jsonl`; only the format needs converting:
+
+```bash
+python -m xiaobin.RL_DAPO.prepare_data \
+    --input /workspace/xiaobin/RL_data/rl_train.jsonl \
+    --output-dir /workspace/xiaobin/RL_DAPO/data \
+    --val-ratio 0.02
+```
+
+The data distribution is unchanged: 50% positive (retrieval needed) + 50% negative (direct answer).
+
+## Training
+
+```bash
+cd /workspace
+bash xiaobin/RL_DAPO/run_dapo.sh
+```
+
+### Key Parameter Adjustments
+
+Adjust the following in `run_dapo.sh` to match your hardware:
+
+```bash
+# Number of GPUs
+NNODES=1                  # number of nodes
+NGPUS_PER_NODE=8          # GPUs per node
+
+# Batch sizes (adjust to available VRAM)
+train_prompt_bsz=64       # training batch size (number of prompts)
+n_resp_per_prompt=8       # responses generated per prompt
+train_prompt_mini_bsz=8   # mini-batch size
+
+# vLLM tensor parallelism (depends on model size)
+GEN_TP=1                  # use 1 for 7B, 4 for 32B
+
+# If VRAM is tight, reduce:
+max_prompt_length=512
+max_response_length=1024
+```
+
+### Single-GPU Test
+
+```bash
+NNODES=1 NGPUS_PER_NODE=1 GEN_TP=1 bash xiaobin/RL_DAPO/run_dapo.sh
+```
+
+## Reward Design
+
+The GRPO reward logic is fully retained:
+
+```
+R_total = R_outcome + R_rel + R_penalty + R_format
+
+R_outcome: F1-based, mapped to [-1, 1]
+R_rel:     SigLIP retrieval relevance (simplified in lightweight mode)
+R_penalty: step penalty (0 steps = 0, 1 = -0.05, 2 = -0.15, 3 = -0.3, timeout = -1.5)
+R_format:  format reward (+0.1 with <answer> present, -0.1 with no structure)
+```
+
+DAPO additionally adds an overlong-buffer penalty: the portion of a response beyond `max_response_length - buffer_len` is penalized linearly.
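+
+For reference, the overlong-buffer shaping behaves roughly as in the sketch below
+(illustrative only; the parameter names mirror the config knobs above, and the
+defaults here are assumptions):
+
+```python
+def overlong_penalty(response_length: int, max_response_length: int = 1024,
+                     buffer_len: int = 256, penalty_factor: float = 1.0) -> float:
+    """Linear penalty for responses that run into the overlong buffer (sketch)."""
+    expected_len = max_response_length - buffer_len
+    exceed = response_length - expected_len
+    if exceed <= 0:
+        return 0.0  # within the expected length: no penalty
+    # Ramps linearly from 0 down to -penalty_factor across the buffer region.
+    return -min(exceed / buffer_len, 1.0) * penalty_factor
+```
+
+At `response_length == max_response_length` the penalty bottoms out at `-penalty_factor`;
+shorter overruns into the buffer are penalized proportionally.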
diff --git a/ICL/RL_DAPO/compute_score.py b/ICL/RL_DAPO/compute_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..042bfeea799b1a52345ff97111726f21003c568e
--- /dev/null
+++ b/ICL/RL_DAPO/compute_score.py
@@ -0,0 +1,84 @@
+"""
+compute_score bridge for verl's reward manager.
+
+verl's DAPORewardManager calls compute_score(data_source, solution_str, ground_truth, extra_info)
+for each sample. We adapt our GRPO reward logic to this interface.
+
+Since SigLIP retrieval requires a running environment (model + candidate pools),
+we provide two modes:
+    1. Lightweight mode (default): reward based on F1 + format penalties only (no SigLIP)
+    2. Full mode: with SigLIP retrieval scoring (requires environment initialization)
+
+For DAPO training, the lightweight mode is recommended as the primary reward signal,
+since the rollout itself doesn't execute retrieval. SigLIP-based R_rel can be added
+as a secondary reward via a custom reward manager if needed.
+"""
+from typing import Any, Dict, Optional
+
+from .reward_functions import (
+    compute_f1_score,
+    count_retrieval_steps,
+    extract_answer,
+)
+
+# Trajectory markup tokens. NOTE: these literals are assumed to match the
+# <ret>/<answer> tags produced by the rollout and parsed by reward_functions.
+RET_TAG = "<ret>"
+ANSWER_TAG = "<answer>"
+
+
+def compute_score(
+    data_source: str,
+    solution_str: str,
+    ground_truth: Any,
+    extra_info: Optional[Dict] = None,
+) -> float:
+    """
+    Compute reward score for a single sample.
+    This is the function verl's reward manager will call.
+
+    Args:
+        data_source: dataset name / reward_fn_key (e.g. "retrieval_vqa")
+        solution_str: model's generated response text
+        ground_truth: ground truth answer (str or dict)
+        extra_info: optional dict with additional info (image path, dataset, id, etc.)
+
+    Returns:
+        float: reward score
+    """
+    # Extract ground truth string
+    if isinstance(ground_truth, dict):
+        gt_str = ground_truth.get("answer", str(ground_truth))
+    elif isinstance(ground_truth, list):
+        gt_str = ground_truth[0] if ground_truth else ""
+    else:
+        gt_str = str(ground_truth)
+
+    # Extract answer from model output
+    prediction = extract_answer(solution_str)
+    num_steps = count_retrieval_steps(solution_str)
+
+    # R_outcome: F1-based continuous reward, mapped from [0, 1] to [-1, 1]
+    f1 = compute_f1_score(prediction, gt_str)
+    r_outcome = 2.0 * f1 - 1.0
+
+    # R_penalty: step-based penalty; a trajectory that keeps retrieving without
+    # ever emitting an answer tag counts as a timeout
+    has_answer = ANSWER_TAG in solution_str
+    max_turns = 3
+    num_rets = solution_str.count(RET_TAG)
+    timeout = num_rets >= max_turns and not has_answer
+
+    if timeout:
+        r_penalty = -1.5
+    else:
+        penalty_map = {0: 0.0, 1: -0.05, 2: -0.15, 3: -0.3}
+        r_penalty = penalty_map.get(num_steps, -0.3)
+
+    # R_format: encourage proper use of the answer tag
+    r_format = 0.0
+    if has_answer:
+        r_format = 0.1  # small bonus for proper formatting
+    elif num_rets == 0:
+        r_format = -0.1  # penalty for no structure at all
+
+    r_total = r_outcome + r_penalty + r_format
+
+    return r_total
diff --git a/ICL/RL_DAPO/dapo_ray_trainer.py b/ICL/RL_DAPO/dapo_ray_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ca97232755122f0b7beb1fd862fab0b867e85aa
--- /dev/null
+++ b/ICL/RL_DAPO/dapo_ray_trainer.py
@@ -0,0 +1,385 @@
+"""
+DAPO Ray Trainer for retrieval-augmented VQA.
+Extends the official DAPO RayDAPOTrainer from verl-recipe.
+
+Key DAPO features (vs GRPO):
+  - Decoupled clip ratios (clip_ratio_low, clip_ratio_high)
+  - No KL penalty (use_kl_in_reward=False, use_kl_loss=False)
+  - Dynamic sampling with group filtering (filter_groups)
+  - Token-level loss aggregation (loss_agg_mode=token-mean)
+  - Overlong reward shaping (overlong_buffer)
+"""
+
+import os
+import uuid
+from collections import defaultdict
+from copy import deepcopy
+from pprint import pprint
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from verl import DataProto
+from verl.trainer.ppo.core_algos import agg_loss
+from verl.trainer.ppo.metric_utils import (
+    compute_data_metrics,
+    compute_throughout_metrics,
+    compute_timing_metrics,
+)
+from verl.trainer.ppo.ray_trainer import (
+    AdvantageEstimator,
+    RayPPOTrainer,
+    apply_kl_penalty,
+    compute_advantage,
+    compute_response_mask,
+)
+from verl.trainer.ppo.reward import compute_reward
+from verl.utils.metric import reduce_metrics
+from verl.utils.profiler import marked_timer
+from verl.utils.rollout_skip import RolloutSkip
+
+
+class RayDAPOTrainer(RayPPOTrainer):
+    """
+    DAPO Trainer for retrieval-augmented VQA.
+ Inherits from RayPPOTrainer, overrides fit() to implement DAPO-specific logic: + - Dynamic sampling with group filtering + - Decoupled clip (handled by verl core via clip_ratio_low/high config) + - Token-level loss (handled by verl core via loss_agg_mode config) + - No KL penalty + """ + + def compute_kl_related_metrics(self, batch: DataProto, metrics: dict, timing_raw: dict): + batch.batch["response_mask"] = compute_response_mask(batch) + + # recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, "blue"): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + with marked_timer("ref", timing_raw, "olive"): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + return batch + + def fit(self): + """ + The DAPO training loop. + Key differences from standard PPO: + - Group filtering: skip groups where all responses have same reward + - Dynamic sampling: keep generating until enough qualified groups + """ + from omegaconf import OmegaConf + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + self.gen_steps = 0 + self._load_checkpoint() + + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): + rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) + rollout_skip.wrap_generate_sequences() + + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="DAPO Training") + + self.global_steps += 1 + self.gen_steps += 1 + last_val_metrics = None + + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + timing_raw = defaultdict(float) + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False) + metrics = {} + + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + + new_batch: DataProto = DataProto.from_single_dict(batch_dict) + num_gen_batches += 1 + gen_batch = 
self._get_gen_batch(new_batch) + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + + is_last_step = self.global_steps >= self.total_training_steps + + with marked_timer("step", timing_raw): + # Generate a batch + with marked_timer("gen", timing_raw, "red"): + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + with marked_timer("gen_max", timing_raw, "red"): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + new_batch = new_batch.union(gen_baseline_output) + rm_scores = None + if self.use_rm and "rm_scores" not in new_batch.batch.keys(): + rm_scores = self.rm_wg.compute_rm_score(new_batch) + new_batch = new_batch.union(rm_scores) + reward_baseline_tensor, _ = compute_reward(new_batch, self.reward_fn) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + keys_to_pop = set(gen_baseline_output.batch.keys()) + if rm_scores is not None: + keys_to_pop.update(rm_scores.batch.keys()) + new_batch.pop(batch_keys=list(keys_to_pop)) + new_batch.batch["reward_baselines"] = reward_baseline_tensor + del rm_scores, gen_baseline_batch, gen_baseline_output + + new_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object + ) + new_batch = new_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + new_batch = new_batch.union(gen_batch_output) + + if self.config.algorithm.use_kl_in_reward: + new_batch = self.compute_kl_related_metrics(new_batch, metrics, timing_raw) + + with marked_timer("reward", timing_raw, "yellow"): + if self.use_rm and "rm_scores" not in new_batch.batch.keys(): + reward_tensor = self.rm_wg.compute_rm_score(new_batch) + new_batch = new_batch.union(reward_tensor) + + reward_tensor, reward_extra_infos_dict = compute_reward(new_batch, self.reward_fn) + new_batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + new_batch.non_tensor_batch.update( + {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + ) + + if self.config.algorithm.use_kl_in_reward: + new_batch, kl_metrics = apply_kl_penalty( + new_batch, kl_ctrl=self.kl_ctrl_in_reward, + kl_penalty=self.config.algorithm.kl_penalty, + ) + metrics.update(kl_metrics) + else: + new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] + + # === DAPO Dynamic Sampling with Group Filtering === + if not self.config.algorithm.filter_groups.enable: + batch = new_batch + else: + metric_name = self.config.algorithm.filter_groups.metric + if metric_name == "seq_final_reward": + new_batch.non_tensor_batch["seq_final_reward"] = ( + new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() + ) + elif metric_name == "seq_reward": + new_batch.non_tensor_batch["seq_reward"] = ( + new_batch.batch["token_level_scores"].sum(dim=-1).numpy() + ) + + prompt_uid2metric_vals = defaultdict(list) + for uid, metric_val in zip( + new_batch.non_tensor_batch["uid"], + new_batch.non_tensor_batch[metric_name], + strict=True, + ): + prompt_uid2metric_vals[uid].append(metric_val) + + prompt_uid2metric_std = {} + for prompt_uid, metric_vals in prompt_uid2metric_vals.items(): + prompt_uid2metric_std[prompt_uid] = 
np.std(metric_vals)
+
+                        kept_prompt_uids = [
+                            uid
+                            for uid, std in prompt_uid2metric_std.items()
+                            if std > 0 or len(prompt_uid2metric_vals[uid]) == 1
+                        ]
+                        num_prompt_in_batch += len(kept_prompt_uids)
+
+                        kept_traj_idxs = []
+                        for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
+                            if traj_from_prompt_uid in kept_prompt_uids:
+                                kept_traj_idxs.append(idx)
+
+                        new_batch = new_batch[kept_traj_idxs]
+                        batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
+
+                        prompt_bsz = self.config.data.train_batch_size
+                        if num_prompt_in_batch < prompt_bsz:
+                            print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
+                            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
+                            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
+                                print(f"{num_gen_batches=}. Keep generating...")
+                                self.gen_steps += 1
+                                is_last_step = self.global_steps >= self.total_training_steps
+                                continue
+                            else:
+                                raise ValueError(
+                                    f"{num_gen_batches=} >= {max_num_gen_batches=}."
+                                    + " Generated too many. Please check if your data are too difficult."
+                                    + " You could also try setting max_num_gen_batches=0 to enable endless trials."
+                                )
+                        else:
+                            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
+                            batch = batch[:traj_bsz]
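+
+                    # Example with the run_dapo.sh defaults (train_batch_size=64, rollout.n=8):
+                    # a group whose 8 responses all receive the same reward carries zero
+                    # advantage signal, so it is filtered out; generation repeats (up to
+                    # max_num_gen_batches) until 64 informative groups remain, i.e.
+                    # 64*8=512 trajectories enter the update below.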
marked_timer("testing", timing_raw, "green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with marked_timer("save_checkpoint", timing_raw, "green"): + self._save_checkpoint() + + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + # Collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + timing_raw = defaultdict(float) + + metrics["train/num_gen_batches"] = num_gen_batches + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + + logger.log(data=metrics, step=self.global_steps) + + if is_last_step: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True) + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + progress_bar.update(1) + self.global_steps += 1 + self.gen_steps += 1 + + # Save last step checkpoint if not already saved + checkpoint_dir = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}") + if not os.path.exists(checkpoint_dir): + timing_raw = defaultdict(float) + with marked_timer("save_checkpoint", timing_raw, "green"): + self._save_checkpoint() + metrics = {f"timing/{k}": v for k, v in timing_raw.items()} + logger.log(data=metrics, step=self.global_steps) diff --git a/ICL/RL_DAPO/environment.py b/ICL/RL_DAPO/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..1861bd3417434858fc942cff8ed3bf45555efe5b --- /dev/null +++ b/ICL/RL_DAPO/environment.py @@ -0,0 +1,96 @@ +""" +Environment for DAPO rollout with SigLIP-based retrieval. +Migrated from GRPO - handles -> retrieve top-1 -> inject -> continue loop. +""" +import torch +import torch.nn.functional as F +from typing import List, Dict, Any, Tuple, Optional +from PIL import Image +import numpy as np + + +class RetrievalEnvironment: + """ + Environment that handles the retrieval loop during rollout. + Used by the reward compute_score to evaluate retrieval quality. 
+ """ + + def __init__( + self, + siglip_model, + siglip_processor, + candidate_pools: Dict[str, List[Dict]], + max_turns: int = 3, + device: str = "cuda", + ): + self.siglip_model = siglip_model + self.siglip_processor = siglip_processor + self.candidate_pools = candidate_pools + self.max_turns = max_turns + self.device = device + self.pool_embeddings = {} + self._precompute_embeddings() + + def _precompute_embeddings(self): + """Precompute SigLIP embeddings for all candidate images.""" + print("Precomputing SigLIP embeddings for candidate pools...") + self.siglip_model.eval() + model_device = next(self.siglip_model.parameters()).device + + for dataset_name, candidates in self.candidate_pools.items(): + images = [c["image"] for c in candidates] + embeddings = [] + batch_size = 32 + for i in range(0, len(images), batch_size): + batch_images = images[i : i + batch_size] + inputs = self.siglip_processor(images=batch_images, return_tensors="pt", padding=True) + inputs = {k: v.to(model_device) for k, v in inputs.items()} + with torch.no_grad(): + image_embeds = self.siglip_model.get_image_features(**inputs) + if hasattr(image_embeds, "pooler_output"): + image_embeds = image_embeds.pooler_output + elif not isinstance(image_embeds, torch.Tensor): + image_embeds = image_embeds[0] + image_embeds = F.normalize(image_embeds, p=2, dim=-1) + embeddings.append(image_embeds.cpu()) + self.pool_embeddings[dataset_name] = torch.cat(embeddings, dim=0) + print(f" {dataset_name}: {len(candidates)} images") + + def retrieve_top1(self, query_text: str, dataset_name: str, exclude_id: str) -> Tuple[Dict, float]: + """Retrieve top-1 most similar image based on query text.""" + text_inputs = self.siglip_processor(text=[query_text], return_tensors="pt", padding=True) + model_device = next(self.siglip_model.parameters()).device + text_inputs = {k: v.to(model_device) for k, v in text_inputs.items()} + + with torch.no_grad(): + text_embeds = self.siglip_model.get_text_features(**text_inputs) + if hasattr(text_embeds, "pooler_output"): + text_embeds = text_embeds.pooler_output + elif not isinstance(text_embeds, torch.Tensor): + text_embeds = text_embeds[0] + text_embeds = F.normalize(text_embeds, p=2, dim=-1) + + image_embeds = self.pool_embeddings[dataset_name].to(text_embeds.device) + similarities = (text_embeds @ image_embeds.T).squeeze(0) + candidates = self.candidate_pools[dataset_name] + valid_indices = [i for i, c in enumerate(candidates) if c["id"] != exclude_id] + valid_similarities = similarities[valid_indices] + top_idx = valid_similarities.argmax().item() + actual_idx = valid_indices[top_idx] + top_score = valid_similarities[top_idx].item() + return candidates[actual_idx], top_score + + def compute_image_similarity(self, query_image: Image.Image, retrieved_image: Image.Image) -> float: + """Compute SigLIP image-to-image similarity for R_rel reward.""" + inputs = self.siglip_processor(images=[query_image, retrieved_image], return_tensors="pt", padding=True) + model_device = next(self.siglip_model.parameters()).device + inputs = {k: v.to(model_device) for k, v in inputs.items()} + with torch.no_grad(): + image_embeds = self.siglip_model.get_image_features(**inputs) + if hasattr(image_embeds, "pooler_output"): + image_embeds = image_embeds.pooler_output + elif not isinstance(image_embeds, torch.Tensor): + image_embeds = image_embeds[0] + image_embeds = F.normalize(image_embeds, p=2, dim=-1) + similarity = (image_embeds[0] @ image_embeds[1]).item() + return similarity diff --git a/ICL/RL_DAPO/main_dapo.py 
b/ICL/RL_DAPO/main_dapo.py new file mode 100644 index 0000000000000000000000000000000000000000..399f00149b5d2df7a019083f5940b9562b9ac982 --- /dev/null +++ b/ICL/RL_DAPO/main_dapo.py @@ -0,0 +1,151 @@ +""" +Main entry point for DAPO training on retrieval-augmented VQA. +Uses verl framework with custom reward function. +""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.ppo.reward import load_reward_manager +from verl.utils.device import auto_set_device, is_cuda_available + +from .dapo_ray_trainer import RayDAPOTrainer + + +@hydra.main(config_path="config", config_name="dapo_trainer", version_base=None) +def main(config): + auto_set_device(config) + run_dapo(config) + + +def run_dapo(config) -> None: + if not ray.is_initialized(): + default_runtime_env = get_ppo_ray_runtime_env() + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + try: + if ( + is_cuda_available + and config.global_profiler.tool == "nsys" + and OmegaConf.select(config.global_profiler, "steps") is not None + and len(OmegaConf.select(config.global_profiler, "steps")) > 0 + ): + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + finally: + if ray.is_initialized(): + ray.shutdown() + + +@ray.remote(num_cpus=1) +class TaskRunner: + def run(self, config): + from pprint import pprint + from omegaconf import OmegaConf + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + local_path = copy_to_local(config.actor_rollout_ref.model.path) + + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + from verl.single_controller.ray import RayWorkerGroup + + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} + from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker, CriticWorker + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.workers.megatron_workers import AsyncActorRolloutRefWorker, CriticWorker + ray_worker_group_cls = RayWorkerGroup + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker), + Role.Critic: ray.remote(CriticWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: 
global_pool_id,
+        }
+
+        if config.reward_model.enable:
+            if config.reward_model.strategy in {"fsdp", "fsdp2"}:
+                from verl.workers.fsdp_workers import RewardModelWorker
+            elif config.reward_model.strategy == "megatron":
+                from verl.workers.megatron_workers import RewardModelWorker
+            else:
+                raise NotImplementedError
+            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
+            mapping[Role.RewardModel] = global_pool_id
+
+        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+            role_worker_mapping[Role.RefPolicy] = ray.remote(AsyncActorRolloutRefWorker)
+            mapping[Role.RefPolicy] = global_pool_id
+
+        reward_fn = load_reward_manager(
+            config,
+            tokenizer,
+            0,
+            max_resp_len=config.data.max_response_length,
+            overlong_buffer_cfg=config.reward_model.overlong_buffer,
+        )
+
+        val_reward_fn = load_reward_manager(
+            config,
+            tokenizer,
+            1,
+            max_resp_len=config.data.max_response_length,
+            overlong_buffer_cfg=config.reward_model.overlong_buffer,
+        )
+
+        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
+
+        trainer = RayDAPOTrainer(
+            config=config,
+            tokenizer=tokenizer,
+            processor=processor,
+            role_worker_mapping=role_worker_mapping,
+            resource_pool_manager=resource_pool_manager,
+            ray_worker_group_cls=ray_worker_group_cls,
+            reward_fn=reward_fn,
+            val_reward_fn=val_reward_fn,
+        )
+        trainer.init_workers()
+        trainer.fit()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ICL/RL_DAPO/prepare_data.py b/ICL/RL_DAPO/prepare_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..f854f213b5b79d961ae5a9467bd67dcd65949fe2
--- /dev/null
+++ b/ICL/RL_DAPO/prepare_data.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+"""
+Convert GRPO rl_train.jsonl to verl-compatible parquet format.
+Reuses the same data distribution from GRPO (50:50 positive/negative).
+
+Usage:
+    python prepare_data.py \
+        --input /workspace/xiaobin/RL_data/rl_train.jsonl \
+        --output-dir /workspace/xiaobin/RL_DAPO/data \
+        --val-ratio 0.02
+"""
+import argparse
+import json
+import os
+from pathlib import Path
+
+import pandas as pd
+
+
+SYSTEM_PROMPT = (
+    "You are a visual question answering assistant. "
+    "You can either answer directly using <answer> or retrieve similar images using <ret>.\n"
+    "- Use <ret> followed by a description when you need to see similar examples.\n"
+    "- Use <answer> followed by your answer when you're ready to answer."
+)
+
+
+def build_prompt(question: str, image_path: str) -> str:
+    """
+    Build a chat-template prompt string for verl.
+    verl expects a 'prompt' column that can be tokenized by the model's chat template.
+    For VL models, we embed the image path as a placeholder token.
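+
+    Example (hypothetical inputs; output abbreviated):
+        build_prompt("What is the man holding?", "/data/img/0001.jpg")
+        -> '[{"role": "system", "content": "..."},
+            {"role": "user", "content": [
+                {"type": "image", "image": "/data/img/0001.jpg"},
+                {"type": "text", "text": "Question: What is the man holding?"}]}]'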
+    """
+    # For Qwen3-VL style, the prompt is a list of messages serialized as JSON
+    messages = [
+        {"role": "system", "content": SYSTEM_PROMPT},
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image_path},
+                {"type": "text", "text": f"Question: {question}"},
+            ],
+        },
+    ]
+    return json.dumps(messages, ensure_ascii=False)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Convert GRPO JSONL to verl parquet")
+    parser.add_argument("--input", default="/workspace/xiaobin/RL_data/rl_train.jsonl",
+                        help="Path to GRPO rl_train.jsonl")
+    parser.add_argument("--output-dir", default="/workspace/xiaobin/RL_DAPO/data",
+                        help="Output directory for parquet files")
+    parser.add_argument("--val-ratio", type=float, default=0.02,
+                        help="Fraction of data for validation")
+    parser.add_argument("--seed", type=int, default=42)
+    args = parser.parse_args()
+
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Load GRPO jsonl
+    records = []
+    with open(args.input, "r", encoding="utf-8") as f:
+        for line in f:
+            if not line.strip():
+                continue
+            rec = json.loads(line)
+            records.append(rec)
+
+    print(f"Loaded {len(records)} records from {args.input}")
+
+    # Convert to verl format
+    rows = []
+    for rec in records:
+        row = {
+            "prompt": build_prompt(rec["question"], rec["image"]),
+            "answer": rec["answer"],
+            "image": rec["image"],
+            "dataset": rec.get("subdir", "unknown"),
+            "id": rec["id"],
+            "category": rec.get("category", "unknown"),
+            # reward_fn_key: verl uses this to dispatch reward functions
+            "reward_fn_key": "retrieval_vqa",
+        }
+        rows.append(row)
+
+    df = pd.DataFrame(rows)
+
+    # Shuffle and split
+    df = df.sample(frac=1.0, random_state=args.seed).reset_index(drop=True)
+    n_val = max(1, int(len(df) * args.val_ratio))
+    val_df = df.iloc[:n_val]
+    train_df = df.iloc[n_val:]
+
+    train_path = os.path.join(args.output_dir, "train.parquet")
+    val_path = os.path.join(args.output_dir, "val.parquet")
+
+    train_df.to_parquet(train_path, index=False)
+    val_df.to_parquet(val_path, index=False)
+
+    print(f"Train: {len(train_df)} samples -> {train_path}")
+    print(f"Val: {len(val_df)} samples -> {val_path}")
+    print(f"Columns: {list(df.columns)}")
+    print("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ICL/RL_DAPO/requirements.txt b/ICL/RL_DAPO/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..150ba0ec59cb0f3f729f3158a7107f9c5bfa7cab
--- /dev/null
+++ b/ICL/RL_DAPO/requirements.txt
@@ -0,0 +1,43 @@
+# ============================================================
+# RL_DAPO Dependencies
+# ============================================================
+# Install everything from the Tsinghua mirror:
+#   pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
+#
+# flash-attn: install from the local wheel:
+#   pip install /workspace/flash_attn-2.8.3+cu12torch2.4cxx11abiTRUE-cp311-cp311-linux_x86_64.whl
+#
+# vLLM needs no local install; it is served remotely via NorthServe.

# --- Core ---
torch>=2.4.0

# --- verl (RLHF framework) ---
verl

# --- Hugging Face ---
transformers>=4.45.0
datasets>=2.14.0
accelerate>=0.27.0

# --- Ray (distributed backend for verl) ---
ray[default]>=2.10.0

# --- SigLIP retrieval environment ---
pillow>=10.0.0
scikit-learn>=1.3.0

# --- Data ---
pandas>=2.0.0
pyarrow>=14.0.0

# --- Config & Logging ---
hydra-core>=1.3.0
omegaconf>=2.3.0
wandb>=0.15.0

# --- General ---
numpy>=1.24.0
tqdm>=4.65.0
sentencepiece>=0.2.0
protobuf>=3.20.0
diff --git a/ICL/RL_DAPO/reward_functions.py b/ICL/RL_DAPO/reward_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9b7353e4853415dd88e23ae101e0874f0ee8c43
--- /dev/null
+++ b/ICL/RL_DAPO/reward_functions.py
@@ -0,0 +1,105 @@
+"""
+Reward functions for DAPO training (migrated from GRPO).
+R_total = R_outcome + Gate(IsCorrect) * R_rel + R_penalty
+
+These are the core reward computation functions, independent of verl's framework.
+"""
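+
+# Worked example of the shaping implemented below: one <ret> step, F1 = 0.8,
+# raw SigLIP score 0.65 -> scaled to (0.65-0.33)/(0.97-0.33) = 0.5, so
+#   r_outcome = 2*0.8 - 1             =  0.60
+#   r_penalty (1 step, no timeout)    = -0.05
+#   r_rel = 0.5 * max(0.3, 0.8) * 0.5 =  0.20
+#   r_total                           =  0.75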
+import re
+from typing import List, Dict, Any
+
+
+def scale_siglip_score(raw_score: float, min_th: float = 0.33, max_th: float = 0.97) -> float:
+    """Scale SigLIP similarity score to [0, 1] range."""
+    scaled = (raw_score - min_th) / (max_th - min_th)
+    return max(0.0, min(1.0, scaled))
+
+
+def compute_f1_score(prediction: str, ground_truth: str) -> float:
+    """Compute F1 score between prediction and ground truth."""
+    pred_tokens = set(prediction.lower().split())
+    gt_tokens = set(ground_truth.lower().split())
+    if len(pred_tokens) == 0 or len(gt_tokens) == 0:
+        return 0.0
+    common = pred_tokens & gt_tokens
+    if len(common) == 0:
+        return 0.0
+    precision = len(common) / len(pred_tokens)
+    recall = len(common) / len(gt_tokens)
+    return 2 * precision * recall / (precision + recall)
+
+
+def extract_answer(text: str) -> str:
+    """Extract the answer from model output after the <answer> token."""
+    if "<answer>" in text:
+        return text.split("<answer>")[-1].strip()
+    return text.strip()
+
+
+def count_retrieval_steps(trajectory: str) -> int:
+    """Count the number of <ret> tokens in the trajectory."""
+    return trajectory.count("<ret>")
+
+
+def compute_reward(
+    trajectory: str,
+    ground_truth: str,
+    siglip_scores: List[float],
+    max_turns: int = 3,
+    timeout: bool = False,
+) -> Dict[str, float]:
+    """
+    Compute total reward for a trajectory.
+    R_total = R_outcome + R_rel + R_penalty (R_rel is softly gated by F1).
+    """
+    prediction = extract_answer(trajectory)
+    num_steps = count_retrieval_steps(trajectory)
+
+    # R_outcome: continuous outcome reward
+    f1 = compute_f1_score(prediction, ground_truth)
+    r_outcome = 2.0 * f1 - 1.0
+
+    is_correct = f1 > 0.3
+
+    # R_penalty
+    if timeout:
+        r_penalty = -1.5
+    else:
+        penalty_map = {0: 0.0, 1: -0.05, 2: -0.15, 3: -0.3}
+        r_penalty = penalty_map.get(num_steps, -0.3)
+
+    # R_rel: retrieval relevance
+    r_rel = 0.0
+    if len(siglip_scores) > 0:
+        scaled_scores = [scale_siglip_score(score) for score in siglip_scores]
+        avg_rel = sum(scaled_scores) / len(scaled_scores)
+        r_rel = avg_rel * max(0.3, f1) * 0.5
+
+    r_total = r_outcome + r_rel + r_penalty
+
+    return {
+        "r_outcome": r_outcome,
+        "r_penalty": r_penalty,
+        "r_rel": r_rel,
+        "r_total": r_total,
+        "f1_score": f1,
+        "num_steps": num_steps,
+        "is_correct": is_correct,
+        "timeout": timeout,
+    }
+
+
+def batch_compute_rewards(
+    trajectories: List[str],
+    ground_truths: List[str],
+    siglip_scores_list: List[List[float]],
+    max_turns: int = 3,
+    timeouts: List[bool] = None,
+) -> List[float]:
+    """Compute rewards for a batch of trajectories."""
+    if timeouts is None:
+        timeouts = [False] * len(trajectories)
+    rewards = []
+    for traj, gt, scores, timeout in zip(trajectories, ground_truths, siglip_scores_list, timeouts):
+        reward_dict = compute_reward(traj, gt, scores, max_turns, timeout)
+        rewards.append(reward_dict["r_total"])
+    return rewards
diff --git a/ICL/RL_DAPO/run_dapo.sh b/ICL/RL_DAPO/run_dapo.sh
new file mode 100644
index 0000000000000000000000000000000000000000..70480c32c8e39b9574f1a918484b659a0179fa1f
--- /dev/null
+++ b/ICL/RL_DAPO/run_dapo.sh
@@ -0,0 +1,148 @@
+#!/usr/bin/env bash
+set -xeuo pipefail
+
+# ============================================================
+# DAPO Training for Retrieval-Augmented VQA
(Qwen3-VL + SigLIP) +# Migrated from GRPO to DAPO using verl framework +# ============================================================ + +project_name='RL-DAPO-VQA' +exp_name='dapo-qwen3vl-retrieval-vqa' + +# === DAPO Algorithm Settings === +adv_estimator=grpo + +# DAPO: No KL penalty +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +# DAPO: Decoupled clip ratios (Clip-Higher) +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# === Sequence Lengths === +max_prompt_length=1024 +max_response_length=2048 +enable_overlong_buffer=True +overlong_buffer_len=512 +overlong_penalty_factor=1.0 + +# DAPO: Token-level loss +loss_agg_mode="token-mean" + +# === Dynamic Sampling === +enable_filter_groups=True +filter_groups_metric=seq_final_reward +max_num_gen_batches=5 + +# === Batch Sizes (adjust based on your GPU count) === +train_prompt_bsz=64 +gen_prompt_bsz=${train_prompt_bsz} +n_resp_per_prompt=8 +train_prompt_mini_bsz=8 + +# === Hardware === +NNODES=${NNODES:-1} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +# === Paths === +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +MODEL_PATH=${MODEL_PATH:-"/workspace/xiaobin/SFT_model/hf_qwen3vl_siglip_vqa_iter_0000881"} +CKPTS_DIR=${CKPTS_DIR:-"/workspace/xiaobin/RL_DAPO/checkpoints/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"/workspace/xiaobin/RL_DAPO/data/train.parquet"} +TEST_FILE=${TEST_FILE:-"/workspace/xiaobin/RL_DAPO/data/val.parquet"} + +# === Rollout Settings === +temperature=1.0 +top_p=1.0 +top_k=-1 +val_top_p=0.7 + +# === Performance === +sp_size=1 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=True +gen_tp=${GEN_TP:-1} # Set based on model size, 1 for 7B + +# === Register custom reward function === +export VERL_CUSTOM_REWARD_FN="xiaobin.RL_DAPO.compute_score.compute_score" + +python3 -m xiaobin.RL_DAPO.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.reward_fn_key=reward_fn_key \ + data.truncation='left' \ + data.trust_remote_code=true \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + 
actor_rollout_ref.actor.grad_clip=1.0 \
+    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+    actor_rollout_ref.actor.optim.weight_decay=0.1 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+    actor_rollout_ref.rollout.enable_chunked_prefill=True \
+    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
+    actor_rollout_ref.rollout.temperature=${temperature} \
+    actor_rollout_ref.rollout.top_p=${top_p} \
+    actor_rollout_ref.rollout.top_k=${top_k} \
+    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
+    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
+    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+    actor_rollout_ref.rollout.val_kwargs.n=1 \
+    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+    reward_model.reward_manager=dapo \
+    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+    reward_model.overlong_buffer.len=${overlong_buffer_len} \
+    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+    trainer.logger='["console","wandb"]' \
+    trainer.project_name="${project_name}" \
+    trainer.experiment_name="${exp_name}" \
+    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
+    trainer.nnodes="${NNODES}" \
+    trainer.val_before_train=True \
+    trainer.test_freq=10 \
+    trainer.save_freq=10 \
+    trainer.total_epochs=3 \
+    trainer.total_training_steps=200 \
+    trainer.default_local_dir="${CKPTS_DIR}" \
+    trainer.resume_mode=auto \
+    trainer.balance_batch=True
diff --git a/ICL/SFT_new/generate_captions_all.py b/ICL/SFT_new/generate_captions_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e74aa8b550b6f37e00258b43f3377312bc7b688
--- /dev/null
+++ b/ICL/SFT_new/generate_captions_all.py
@@ -0,0 +1,313 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Scan all images under /workspace/xiaobin/dataset/images, call a VLM API to
+generate captions, and save one JSON per category/dataset/split.
+
+Output layout:
+    {output_root}/{category}/{dataset}/{split}/captions.json
+
+Each captions.json has the form:
+    {
+        "model": "Qwen3-VL-8B-Instruct",
+        "prompt": "Describe this image ...",
+        "count": 12345,
+        "items": {
+            "/full/path/00000000.jpg": "A woman cutting a cake.",
+            "/full/path/00000001.jpg": "A red car on a street.",
+            ...
+        }
+    }
+
+Resume support:
+    - captions.json exists and count == number of images in the dir -> skip
+    - exists but incomplete -> load existing results, process only the missing images
+    - write to disk every --save-every images (so a crash loses little work)
+
+Usage:
+    # Process everything
+    python generate_captions_all.py --api-base http://10.51.6.110/v1
+
+    # Only certain categories
+    python generate_captions_all.py --categories vqa,captioning
+
+    # Only one dataset
+    python generate_captions_all.py --datasets vqa/a-okvqa
+
+    # Only one split
+    python generate_captions_all.py --datasets vqa/a-okvqa --splits train
+"""
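+
+# Expected input layout (three directory levels below the image root;
+# file names shown are illustrative):
+#   /workspace/xiaobin/dataset/images/{category}/{dataset}/{split}/00000000.jpg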
+
+from __future__ import annotations
+
+import argparse
+import base64
+import json
+import os
+import sys
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+from tqdm import tqdm
+from openai import OpenAI
+
+
+# ---------------------------------------------------------------------------
+# Constants
+# ---------------------------------------------------------------------------
+
+IMAGE_ROOT = Path("/workspace/xiaobin/dataset/images")
+IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
+
+
+# ---------------------------------------------------------------------------
+# Image utilities
+# ---------------------------------------------------------------------------
+
+def encode_image_base64(image_path: str) -> Optional[str]:
+    try:
+        with open(image_path, "rb") as f:
+            return base64.b64encode(f.read()).decode("utf-8")
+    except Exception as e:
+        print(f"[WARN] Failed to read {image_path}: {e}")
+        return None
+
+
+def get_mime_type(image_path: str) -> str:
+    ext = os.path.splitext(image_path)[1].lower().lstrip(".")
+    return {"jpg": "image/jpeg", "jpeg": "image/jpeg", "png": "image/png",
+            "gif": "image/gif", "bmp": "image/bmp", "webp": "image/webp"
+            }.get(ext, "image/jpeg")
+
+
+# ---------------------------------------------------------------------------
+# VLM caption
+# ---------------------------------------------------------------------------
+
+def generate_caption(
+    client: OpenAI, model: str, image_path: str, prompt: str,
+    temperature: float = 0.1, max_tokens: int = 256,
+) -> Optional[str]:
+    b64 = encode_image_base64(image_path)
+    if b64 is None:
+        return None
+    mime = get_mime_type(image_path)
+    messages = [{
+        "role": "user",
+        "content": [
+            {"type": "text", "text": prompt},
+            {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}},
+        ],
+    }]
+    try:
+        resp = client.chat.completions.create(
+            model=model, messages=messages,
+            temperature=temperature, max_tokens=max_tokens,
+        )
+        text = resp.choices[0].message.content
+        return text.strip() if text else None
+    except Exception as e:
+        print(f"[ERROR] API failed for {image_path}: {e}")
+        return None
+
+
+# ---------------------------------------------------------------------------
+# Discover splits
+# ---------------------------------------------------------------------------
+
+def discover_splits(
+    image_root: Path,
+    categories: Optional[List[str]] = None,
+    datasets: Optional[List[str]] = None,
+    splits: Optional[List[str]] = None,
+) -> List[Tuple[str, str, str, Path]]:
+    """Return a list of (category, dataset, split, dir_path)."""
+    results = []
+    for cat_dir in sorted(image_root.iterdir()):
+        if not cat_dir.is_dir():
+            continue
+        cat = cat_dir.name
+        if categories and cat not in categories:
+            continue
+        for ds_dir in sorted(cat_dir.iterdir()):
+            if not ds_dir.is_dir():
+                continue
+            ds = ds_dir.name
+            if datasets and f"{cat}/{ds}" not in datasets:
+                continue
+            for sp_dir in sorted(ds_dir.iterdir()):
+                if not 
sp_dir.is_dir(): + continue + sp = sp_dir.name + if splits and sp not in splits: + continue + results.append((cat, ds, sp, sp_dir)) + return results + + +def list_images(dir_path: Path) -> List[str]: + """List all image files in a directory, sorted.""" + return sorted( + str(dir_path / f.name) + for f in dir_path.iterdir() + if f.is_file() and f.suffix.lower() in IMAGE_EXTS + ) + + +# --------------------------------------------------------------------------- +# Caption cache I/O +# --------------------------------------------------------------------------- + +def load_caption_json(path: Path) -> Dict[str, str]: + if not path.exists(): + return {} + try: + with path.open("r", encoding="utf-8") as f: + data = json.load(f) + items = data.get("items", {}) + return items if isinstance(items, dict) else {} + except Exception: + return {} + + +def save_caption_json( + path: Path, items: Dict[str, str], model: str, prompt: str, +) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + data = {"model": model, "prompt": prompt, "count": len(items), "items": items} + tmp = path.with_suffix(".tmp") + with tmp.open("w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False) + tmp.rename(path) + + +# --------------------------------------------------------------------------- +# Process one split +# --------------------------------------------------------------------------- + +def process_split( + cat: str, ds: str, sp: str, dir_path: Path, + output_root: Path, client: OpenAI, + model: str, prompt: str, num_workers: int, + temperature: float, max_tokens: int, save_every: int, +) -> None: + out_dir = output_root / cat / ds / sp + out_path = out_dir / "captions.json" + + all_images = list_images(dir_path) + if not all_images: + print(f" [SKIP] {cat}/{ds}/{sp}: no images") + return + + existing = load_caption_json(out_path) + + if len(existing) >= len(all_images): + print(f" [SKIP] {cat}/{ds}/{sp}: done ({len(existing)}/{len(all_images)})") + return + + missing = [p for p in all_images if p not in existing] + print(f" [{cat}/{ds}/{sp}] {len(missing)} to caption " + f"({len(existing)} done, {len(all_images)} total)") + + lock = threading.Lock() + results = dict(existing) + + def _worker(img_path: str) -> None: + caption = generate_caption( + client, model, img_path, prompt, + temperature=temperature, max_tokens=max_tokens, + ) + if caption: + with lock: + results[img_path] = caption + + pbar = tqdm(total=len(missing), desc=f" {cat}/{ds}/{sp}", unit="img") + processed_since_save = 0 + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = {executor.submit(_worker, p): p for p in missing} + for future in as_completed(futures): + try: + future.result() + except Exception as e: + print(f"[ERROR] {e}") + pbar.update(1) + processed_since_save += 1 + if processed_since_save >= save_every: + with lock: + save_caption_json(out_path, results, model, prompt) + processed_since_save = 0 + pbar.close() + + save_caption_json(out_path, results, model, prompt) + print(f" [OK] {cat}/{ds}/{sp}: {len(results)}/{len(all_images)}") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main() -> int: + ap = argparse.ArgumentParser(description="Batch VLM caption for all dataset images.") + ap.add_argument("--image-root", default=str(IMAGE_ROOT)) + ap.add_argument("--output-root", default="/workspace/xiaobin/dataset/detail", + help="Output root for caption JSONs") + 
ap.add_argument("--api-base", default="http://10.51.6.110/v1") + ap.add_argument("--api-key", default="EMPTY") + ap.add_argument("--model", default="Qwen3-VL-8B-Instruct") + ap.add_argument("--num-workers", type=int, default=128) + ap.add_argument("--prompt", default=( + "Describe this image in one or two sentences. " + "Focus on the main objects, their attributes, and spatial relationships." + )) + ap.add_argument("--temperature", type=float, default=0.1) + ap.add_argument("--max-tokens", type=int, default=256) + ap.add_argument("--save-every", type=int, default=500) + ap.add_argument("--timeout", type=float, default=600.0) + ap.add_argument("--categories", default="", + help="Comma-separated categories to process (e.g. vqa,captioning). Empty=all") + ap.add_argument("--datasets", default="", + help="Comma-separated category/dataset (e.g. vqa/a-okvqa). Empty=all") + ap.add_argument("--splits", default="", + help="Comma-separated splits (e.g. train,val). Empty=all") + args = ap.parse_args() + + image_root = Path(args.image_root) + output_root = Path(args.output_root) + + cats = [c.strip() for c in args.categories.split(",") if c.strip()] or None + dss = [d.strip() for d in args.datasets.split(",") if d.strip()] or None + sps = [s.strip() for s in args.splits.split(",") if s.strip()] or None + + all_splits = discover_splits(image_root, cats, dss, sps) + if not all_splits: + print("[ERROR] No matching splits found") + return 1 + + print(f"[INFO] {len(all_splits)} splits found") + print(f"[INFO] API: {args.api_base}, Model: {args.model}, Workers: {args.num_workers}") + print(f"[INFO] Output: {output_root}") + + client = OpenAI(api_key=args.api_key, base_url=args.api_base, timeout=args.timeout) + + for i, (cat, ds, sp, dir_path) in enumerate(all_splits, 1): + print(f"\n[{i}/{len(all_splits)}] {cat}/{ds}/{sp}") + t0 = time.time() + process_split( + cat, ds, sp, dir_path, output_root, client, + model=args.model, prompt=args.prompt, + num_workers=args.num_workers, + temperature=args.temperature, max_tokens=args.max_tokens, + save_every=args.save_every, + ) + print(f" Time: {time.time() - t0:.1f}s") + + print("\n[DONE]") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ICL/SFT_new/run_multi_node.sh b/ICL/SFT_new/run_multi_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..a1ede339d8a2e7e51275e749ea0ff4eb39dae865 --- /dev/null +++ b/ICL/SFT_new/run_multi_node.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# ============================================================================= +# Multi-node training (e.g. 4 nodes × 8 GPUs = 32x H100) +# +# Launch method: northjob (your cluster scheduler) +# This script is the actual training entry point on each node. 
+# +# torchrun env vars set by northjob/scheduler: +# MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK, LOCAL_RANK +# ============================================================================= +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PYTHON_BIN="/workspace/miniconda3/envs/sft/bin/python3" + +# ---- Config ---- +MODEL_PATH="/workspace/models/Qwen3-VL-8B-Instruct" +DATA_PATH="/workspace/xiaobin/dataset/sft/all/sft.jsonl" +OUTPUT_DIR="/workspace/xiaobin/ICL/sft_model" + +NUM_GPUS_PER_NODE="${1:-8}" + +# ---- Env ---- +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export NCCL_P2P_DISABLE=0 +export NCCL_IB_DISABLE=0 +export NCCL_SOCKET_IFNAME=bond0 +export PYTHONPATH="${SCRIPT_DIR}:${PYTHONPATH:-}" + +# ---- Training hyperparams ---- +# global_batch = total_GPUs * micro_batch * grad_accum (e.g. 32 GPUs: 32*2*4=256) +BATCH_SIZE=2 +GRAD_ACCUM=4 +NUM_EPOCHS=3 +LR=1e-5 +MAX_LENGTH=24576 + +echo "============================================" +echo "Multi-node SFT: GPUs/node=${NUM_GPUS_PER_NODE}" +echo "MASTER_ADDR=${MASTER_ADDR:-localhost}" +echo "WORLD_SIZE=${WORLD_SIZE:-${NUM_GPUS_PER_NODE}}" +echo "RANK=${RANK:-0}" +echo "Model: ${MODEL_PATH}" +echo "Data: ${DATA_PATH}" +echo "Output: ${OUTPUT_DIR}" +echo "============================================" + +$PYTHON_BIN ${SCRIPT_DIR}/train.py \ + --model-path ${MODEL_PATH} \ + --data-path ${DATA_PATH} \ + --output-dir ${OUTPUT_DIR} \ + --deepspeed ${SCRIPT_DIR}/ds_zero3.json \ + --num-epochs ${NUM_EPOCHS} \ + --batch-size ${BATCH_SIZE} \ + --gradient-accumulation-steps ${GRAD_ACCUM} \ + --learning-rate ${LR} \ + --max-length ${MAX_LENGTH} \ + --gradient-checkpointing \ + --warmup-ratio 0.1 \ + --log-interval 1 \ + --save-interval 500
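+
+# Example: single-node smoke test with 8 GPUs. MASTER_*/RANK/WORLD_SIZE are
+# normally exported by northjob; this hypothetical local invocation assumes
+# train.py reads them from the environment:
+#   MASTER_ADDR=localhost MASTER_PORT=29500 WORLD_SIZE=8 RANK=0 bash run_multi_node.sh 8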