Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See raw diff.
- ICL/RL/CHECKLIST.md +198 -0
- ICL/RL/FIX_SUMMARY.md +118 -0
- ICL/RL/README.md +167 -0
- ICL/RL/REWARD_DESIGN.md +259 -0
- ICL/RL/__pycache__/data_utils.cpython-311.pyc +0 -0
- ICL/RL/__pycache__/data_utils.cpython-313.pyc +0 -0
- ICL/RL/__pycache__/environment.cpython-311.pyc +0 -0
- ICL/RL/__pycache__/environment.cpython-313.pyc +0 -0
- ICL/RL/__pycache__/reward_functions.cpython-311.pyc +0 -0
- ICL/RL/__pycache__/reward_functions.cpython-313.pyc +0 -0
- ICL/RL/__pycache__/train_grpo.cpython-313.pyc +0 -0
- ICL/RL/build_rl_dataset.py +420 -0
- ICL/RL/config.yaml +82 -0
- ICL/RL/data_utils.py +192 -0
- ICL/RL/environment.py +240 -0
- ICL/RL/inference_example.py +157 -0
- ICL/RL/key_metrics_20260220_152053.log +10 -0
- ICL/RL/key_metrics_20260224_094601.log +0 -0
- ICL/RL/key_metrics_20260224_133510.log +0 -0
- ICL/RL/plan.md +167 -0
- ICL/RL/plot_metrics.py +127 -0
- ICL/RL/plots/clip_ratio_high.png +0 -0
- ICL/RL/plots/clip_ratio_low.png +0 -0
- ICL/RL/plots/clipped_ratio.png +0 -0
- ICL/RL/plots/completion_length.png +0 -0
- ICL/RL/plots/entropy.png +0 -0
- ICL/RL/plots/frac_reward_zero_std.png +0 -0
- ICL/RL/plots/grad_norm.png +0 -0
- ICL/RL/plots/learning_rate.png +0 -0
- ICL/RL/quickstart.sh +36 -0
- ICL/RL/requirements.txt +16 -0
- ICL/RL/reward_functions.py +135 -0
- ICL/RL/run_grpo.sh +29 -0
- ICL/RL/siglip_analysis/score_distribution.png +0 -0
- ICL/RL/test_device_config.py +111 -0
- ICL/RL/train_grpo.py +389 -0
- ICL/RL/train_grpo_20260224_133510.log +0 -0
- ICL/RL/train_pid.txt +1 -0
- ICL/RL/trl_source/.pre-commit-config.yaml +17 -0
- ICL/RL/trl_source/CITATION.cff +41 -0
- ICL/RL/trl_source/CODE_OF_CONDUCT.md +133 -0
- ICL/RL/trl_source/CONTRIBUTING.md +411 -0
- ICL/RL/trl_source/LICENSE +201 -0
- ICL/RL/trl_source/MANIFEST.in +7 -0
- ICL/RL/trl_source/Makefile +19 -0
- ICL/RL/trl_source/README.md +207 -0
- ICL/RL/trl_source/RELEASE.md +167 -0
- ICL/RL/trl_source/VERSION +1 -0
- ICL/RL/trl_source/pyproject.toml +194 -0
- ICL/RL/trl_source/requirements.txt +3 -0
ICL/RL/CHECKLIST.md
ADDED
@@ -0,0 +1,198 @@
# Pre-Training Checklist

## ✅ Implementation Complete

### Core Components
- [x] Reward functions (R_outcome + R_rel + R_penalty)
- [x] Retrieval environment with SigLIP
- [x] Data loading (M3IT 50:50 split)
- [x] GRPO trainer integration
- [x] Self-exclusion mechanism
- [x] Sub-dataset candidate pools
- [x] Timeout handling

### Testing & Validation
- [x] Reward function unit tests (7/7 passed)
- [x] SigLIP score scaling
- [x] F1 score computation
- [x] Answer extraction
- [x] All test cases validated

### Documentation
- [x] README.md with usage guide
- [x] IMPLEMENTATION_SUMMARY.md
- [x] Code comments and docstrings
- [x] Configuration file (config.yaml)
- [x] Quick start script

### Utilities
- [x] Test script (test_rewards.py)
- [x] Visualization script (visualize_rewards.py)
- [x] Inference example (inference_example.py)
- [x] Requirements.txt

## 🚦 Before Training

### 1. Environment Setup
```bash
cd /workspace/xiaobin/RL
bash quickstart.sh
```

### 2. Verify Model Path
Check that the SFT checkpoint exists:
```bash
ls -lh /workspace/xiaobin/SFT_model/hf_qwen3vl_siglip_vqa_iter_0000881
```

### 3. Test Reward Functions
```bash
python test_rewards.py
```
Expected: All tests pass ✓

### 4. Configure Training
Edit `train_grpo.py`:
- [ ] MODEL_PATH: Verify SFT checkpoint path
- [ ] OUTPUT_DIR: Set output directory
- [ ] LEARNING_RATE: Default 1e-5
- [ ] NUM_GENERATIONS: Default 8 (group size)
- [ ] MAX_TURNS: Default 3
- [ ] MAX_SAMPLES_PER_DATASET: Set to a small number for testing, None for the full dataset

### 5. Hardware Check
- [ ] GPU available: `nvidia-smi`
- [ ] VRAM: Minimum 40GB (A100), recommended 80GB (H100)
- [ ] Disk space: Check that the output directory has space

### 6. Data Access
- [ ] M3IT dataset accessible via HuggingFace
- [ ] Internet connection for dataset download
- [ ] Sufficient disk space for the dataset cache

### 7. Wandb Setup (Optional)
```bash
wandb login
```
Or set `USE_WANDB = False` in train_grpo.py.

## 🎯 Training Stages

### Stage 1: Smoke Test (Recommended)
```python
# In train_grpo.py, set:
MAX_SAMPLES_PER_DATASET = 100  # Small dataset
NUM_EPOCHS = 1
```
Run: `python train_grpo.py`

Expected: Training completes without errors

### Stage 2: Full Training
```python
# In train_grpo.py, set:
MAX_SAMPLES_PER_DATASET = None  # Full dataset
NUM_EPOCHS = 3
```
Run: `python train_grpo.py`

## 📊 Monitoring

### During Training
- [ ] Check wandb dashboard for metrics
- [ ] Monitor GPU utilization: `watch -n 1 nvidia-smi`
- [ ] Check logs in the output/ directory
- [ ] Verify checkpoints are being saved

### Key Metrics to Watch
- **Average reward**: Should increase over time
- **Reward distribution**: Check positive vs negative samples
- **Retrieval rate**: % of samples using <RET>
- **Average steps**: Should decrease (efficiency)
- **Timeout rate**: Should decrease
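A quick way to sanity-check these numbers offline is a minimal sketch like the one below, assuming trajectories are plain strings carrying the `<RET>`/`<ANS>` markers; `trajectory_metrics` is a hypothetical helper, not one of the shipped scripts:

```python
# Minimal sketch: derive retrieval rate, average steps, and timeout rate
# from a batch of trajectory strings. Assumes the <RET>/<ANS> markers used
# by the reward design; not part of the shipped training scripts.
def trajectory_metrics(trajectories, max_turns=3):
    n = len(trajectories)
    steps = [t.count("<RET>") for t in trajectories]
    retrieval_rate = sum(s > 0 for s in steps) / n
    # A trajectory times out if it hits max_turns without ever answering.
    timeout_rate = sum(
        s >= max_turns and "<ANS>" not in t
        for s, t in zip(steps, trajectories)
    ) / n
    return {
        "retrieval_rate": retrieval_rate,
        "avg_steps": sum(steps) / n,
        "timeout_rate": timeout_rate,
    }

print(trajectory_metrics(["<ANS> cat", "<RET> a cat photo <ANS> cat"]))
# {'retrieval_rate': 0.5, 'avg_steps': 0.5, 'timeout_rate': 0.0}
```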

## 🔍 Expected Behavior

### Early Training (Epoch 1)
- Random retrieval decisions
- High timeout rate
- Low average reward
- Inconsistent step counts

### Mid Training (Epoch 2)
- Learning to distinguish positive/negative
- Decreasing timeout rate
- Improving average reward
- More consistent behavior

### Late Training (Epoch 3)
- Clear retrieval strategy
- Low timeout rate
- High average reward
- Efficient step usage (prefer 1-shot)

## 🐛 Troubleshooting

### OOM (Out of Memory)
```python
BATCH_SIZE = 1                   # Already minimal
NUM_GENERATIONS = 4              # Reduce from 8
gradient_accumulation_steps = 8  # Increase
```

### Slow Training
```python
MAX_SAMPLES_PER_DATASET = 1000  # Limit dataset size
```

### Not Learning
- Check reward distribution in wandb
- Verify data balance (50:50)
- Check SigLIP embeddings are precomputed
- Verify self-exclusion is working

### Data Loading Errors
- Check internet connection
- Clear HuggingFace cache: `rm -rf ~/.cache/huggingface`
- Try loading datasets individually

## 📝 Post-Training

### 1. Evaluate Model
```bash
python inference_example.py
```

### 2. Analyze Results
- Check final reward distribution
- Analyze retrieval patterns
- Compare positive vs negative samples
- Measure efficiency (avg steps)

### 3. Save Best Checkpoint
```bash
cp -r output/checkpoint-XXXX output/best_model
```

## ✨ Success Criteria

Training is successful if:
- [x] Average reward > 0.5
- [x] Timeout rate < 5%
- [x] Positive samples: >70% use retrieval
- [x] Negative samples: >70% direct answer
- [x] Average steps: 1.0-1.5 (efficient)

## 🎓 Next Steps After Training

1. Evaluate on held-out test set
2. Analyze failure cases
3. Fine-tune hyperparameters if needed
4. Deploy model for inference
5. Collect user feedback

---

**Status**: Ready for training ✅

All components tested and validated. Proceed with the smoke test first, then full training.
ICL/RL/FIX_SUMMARY.md
ADDED
@@ -0,0 +1,118 @@
# GRPO Training Fix Summary

## Issues Found

### 1. Multi-GPU device mismatch error ❌
**Problem**:
- The VLM and SigLIP models are distributed across multiple GPUs (cuda:0, cuda:1) automatically by `device_map="auto"`
- The Environment class hard-coded `device="cuda"` (which defaults to cuda:0)
- As a result, inside `retrieve_top1`:
  - `text_embeds` could end up on cuda:1
  - `image_embeds` was moved to cuda:0
  - The matrix multiplication then hit a device mismatch: `Expected all tensors to be on the same device, but got mat2 is on cuda:0, different from other tensors on cuda:1`

**Fix**:
- Stop using the hard-coded `self.device`
- Instead, dynamically look up the device the model's parameters live on: `model_device = next(self.siglip_model.parameters()).device`
- Move all input tensors to the model's device
- Ensure all computation happens on the same device

**Modified file**: `environment.py`
- `_precompute_embeddings()`: use the model's device instead of self.device
- `retrieve_top1()`: fetch the model's device dynamically; keep text_embeds and image_embeds on the same device
- `compute_image_similarity()`: use the model's device
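For illustration, here is a minimal, self-contained sketch of the pattern with a toy module (not the actual `environment.py` code): querying the parameters for their device keeps inputs and weights together no matter how `device_map="auto"` sharded the model.

```python
import torch
import torch.nn as nn

# Toy stand-in for the SigLIP model; in practice its parameters may live
# on cuda:1 when device_map="auto" shards the checkpoint across GPUs.
model = nn.Linear(16, 4)

def encode(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Follow the model instead of hard-coding "cuda" (i.e. cuda:0):
    model_device = next(model.parameters()).device
    return model(x.to(model_device))

out = encode(model, torch.randn(2, 16))  # works wherever `model` lives
print(out.shape)  # torch.Size([2, 4])
```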

### 2. Loss near zero; the model is not learning ❌
**Problem**:
- In the middle and late stages of training, the loss sat between -0.0013 and 0.0021, close to zero
- `frac_reward_zero_std`: 0.875-0.975 (for 87.5%-97.5% of samples the reward standard deviation was zero)
- This means that for a given prompt, the 16 generated completions received nearly identical rewards
- GRPO's advantage is reward minus baseline; if all rewards are equal, the advantage is zero and so is the loss

**Root cause**:
- `temperature=0.7` was too low, so the 16 generated completions were too similar
- The lack of diversity drove the reward variance to zero

**Fix**:
- Raise the temperature from 0.7 to 1.1
- More generation diversity lets different completions earn different rewards
- Only then can GRPO compute a meaningful advantage and learn

**Modified file**: `train_grpo.py`
- `GRPOConfig.temperature`: 0.7 → 1.1
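The effect is easy to reproduce numerically. Below is a minimal sketch of a group-relative advantage (not TRL's exact implementation): when every completion in a group gets the same reward, all advantages are zero and so is the gradient signal.

```python
import torch

def group_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # GRPO-style advantage: reward minus the group mean (the baseline),
    # normalized by the group's reward standard deviation.
    return (rewards - rewards.mean()) / (rewards.std() + eps)

same = torch.tensor([1.0, 1.0, 1.0, 1.0])     # zero-variance group
mixed = torch.tensor([1.0, -1.0, 1.0, -1.0])  # diverse completions

print(group_advantages(same))   # tensor([0., 0., 0., 0.]) -> no learning signal
print(group_advantages(mixed))  # nonzero advantages -> useful gradient
```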

### 3. Frequent retrieval errors ❌
**Problem**:
- The logs contained a large number of "Retrieval error" messages
- All of them were caused by the device mismatch

**Fix**:
- Fixing issue 1 (the multi-GPU device configuration) resolves this automatically

## Code Changes

### environment.py

#### 1. `_precompute_embeddings()`
```python
# Before
inputs = self.siglip_processor(...).to(self.device)

# After
model_device = next(self.siglip_model.parameters()).device
inputs = {k: v.to(model_device) for k, v in inputs.items()}
```

#### 2. `retrieve_top1()`
```python
# Before
text_inputs = self.siglip_processor(...).to(self.device)
image_embeds = self.pool_embeddings[dataset_name].to(self.device)

# After
model_device = next(self.siglip_model.parameters()).device
text_inputs = {k: v.to(model_device) for k, v in text_inputs.items()}
image_embeds = self.pool_embeddings[dataset_name].to(text_embeds.device)
```

#### 3. `compute_image_similarity()`
```python
# Before
inputs = self.siglip_processor(...).to(self.device)

# After
model_device = next(self.siglip_model.parameters()).device
inputs = {k: v.to(model_device) for k, v in inputs.items()}
```

### train_grpo.py

```python
# Before
temperature=0.7,

# After
temperature=1.1,  # Increased from 0.7 to 1.1 for more diversity
```

## Verification

Run the test script to verify the device configuration:
```bash
cd /workspace/xiaobin/RL
python test_device_config.py
```

## Expected Results

1. **No more device mismatch errors**: all retrieval operations execute normally
2. **Loss no longer near zero**: with more generation diversity, the reward variance grows and GRPO can learn properly
3. **More stable training**: no OOMs or crashes caused by device errors
4. **Better model performance**: the model can actually learn the retrieval policy

## Further Suggestions

1. **Monitor the reward distribution**: watch whether `frac_reward_zero_std` drops below 0.5
2. **Tune the temperature**: if 1.1 is not enough, try 1.2-1.5
3. **Check generation quality**: make sure the generated text is still coherent at the higher temperature
4. **Save checkpoints**: save the model regularly in case training is interrupted
ICL/RL/README.md
ADDED
@@ -0,0 +1,167 @@
# GRPO Training: Retrieval-Augmented Visual Question Answering

This directory contains a GRPO (Group Relative Policy Optimization) training implementation for a vision-language model that learns when to retrieve additional images and when to answer directly.

## Overview

The model learns to:
- Output `<RET> description text` to retrieve similar images when needed
- Output `<ANS> answer` to answer directly when ready
- Optimize retrieval efficiency (fewer steps is better)
- Maximize answer accuracy

## Files

- `train_grpo.py` - Main training script
- `reward_functions.py` - Reward computation (R_outcome + R_rel + R_penalty)
- `environment.py` - SigLIP-based retrieval environment
- `data_utils.py` - M3IT data loading and preprocessing
- `requirements.txt` - Python dependencies
- `config.yaml` - Training configuration file

## Installation

```bash
pip install -r requirements.txt
```

## Data Strategy

Training uses the M3IT dataset, split into two categories:

**Positive set (50%)** - retrieval required:
- OK-VQA / A-OKVQA (knowledge-intensive)
- ScienceQA (science reasoning)

**Negative set (50%)** - answer directly:
- TextVQA / OCR-VQA (reading text in images)
- VQA-v2 (basic visual recognition)
- CLEVR (pure logical reasoning)

## Reward Function

```
R_total = R_outcome + Gate(answer correct?) × R_rel + R_penalty
```

- **R_outcome**: +1.0 for a correct answer, -1.0 for a wrong one
- **R_penalty**: 0 / -0.1 / -0.25 / -0.5 for 0/1/2/3 steps; -2.0 on timeout
- **R_rel**: normalized SigLIP image similarity (applied only when the answer is correct)
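Putting the three terms together, a minimal sketch of the formula above (the real logic lives in `reward_functions.py`):

```python
# Minimal sketch of the total reward; see reward_functions.py for the real code.
PENALTY = {0: 0.0, 1: -0.1, 2: -0.25, 3: -0.5}

def total_reward(correct: bool, num_steps: int, rel: float, timed_out: bool) -> float:
    r_outcome = 1.0 if correct else -1.0
    r_penalty = -2.0 if timed_out else PENALTY[num_steps]
    r_rel = rel if correct else 0.0  # the gate: relevance only counts when correct
    return r_outcome + r_rel + r_penalty

print(total_reward(correct=True, num_steps=1, rel=0.81, timed_out=False))  # 1.71
```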

## Usage

### Basic training

```bash
python train_grpo.py
```

### Configuration

Edit the configuration in `train_grpo.py`:

```python
MODEL_PATH = "/workspace/xiaobin/SFT_model/hf_qwen3vl_siglip_vqa_iter_0000881"
OUTPUT_DIR = "/workspace/xiaobin/RL/output"
SIGLIP_MODEL = "/workspace/siglip2-so400m-patch16-naflex"
LEARNING_RATE = 1e-5
BATCH_SIZE = 1
NUM_GENERATIONS = 8  # group size
MAX_TURNS = 3
```

### Monitoring training

Training logs are sent to Weights & Biases. Set `USE_WANDB = False` to disable this.

## Retrieval Environment

The environment implements a while-loop interaction (see the sketch below):

1. The model generates output
2. If `<RET>` is detected:
   - Extract the description text
   - Use SigLIP to find the most similar image (excluding the query image)
   - Inject the retrieved image into the conversation
   - Continue to the next turn
3. If `<ANS>` is detected or the maximum number of turns is reached:
   - End the trajectory
   - Compute the reward
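A minimal sketch of that loop (schematic only; `generate`, `retrieve_top1`, and `append_image` are hypothetical stand-ins for the real model and environment methods):

```python
# Schematic interaction loop; generate/retrieve_top1/append_image are
# hypothetical stand-ins, not the actual environment.py API.
def run_episode(generate, retrieve_top1, append_image, prompt, max_turns=3):
    conversation, num_steps = prompt, 0
    while num_steps < max_turns:
        output = generate(conversation)
        if "<ANS>" in output:
            answer = output.split("<ANS>", 1)[1].strip()
            return answer, num_steps, False          # answered in time
        if "<RET>" in output:
            query = output.split("<RET>", 1)[1].strip()
            image = retrieve_top1(query)             # self-excluding retrieval
            conversation = append_image(conversation, image)
            num_steps += 1
            continue
        break                                        # malformed output
    return None, num_steps, True                     # timeout -> -2.0 penalty
```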

## Key Features

- **Per-sub-dataset retrieval pools**: each dataset has its own candidate pool, making retrieval faster and more accurate
- **Self-exclusion**: the query image is always excluded, preventing trivial solutions
- **Non-linear step penalty**: encourages efficiency (1 step > 3 steps)
- **Gated relevance reward**: R_rel applies only when the answer is correct
- **Timeout handling**: infinite loops are punished severely (-2.0)

## Expected Behavior

After training, the model should:
- Retrieve for knowledge-intensive questions (OK-VQA)
- Answer visual questions directly (TextVQA, VQA-v2)
- Use as few retrieval steps as possible (prefer 1 step over 3)
- Write accurate descriptions that retrieve relevant images

## Hardware Requirements

- Recommended: 8×H100 (80GB) or equivalent
- Minimum: 1×A100 (40GB) with a reduced batch size
- SigLIP precomputation needs about 10GB of VRAM

## Troubleshooting

**Out of VRAM**: reduce `BATCH_SIZE` or `NUM_GENERATIONS`

**Retrieval too slow**: reduce `MAX_SAMPLES_PER_DATASET` while debugging

**Model not learning**: check the reward distribution in the wandb logs

## Testing and Validation

Run the test scripts to validate the implementation:

```bash
# Test the reward functions
python test_rewards.py

# Visualize the reward distribution
python visualize_rewards.py
```

## Training Stages

### Stage 1: Smoke test (recommended)
```python
MAX_SAMPLES_PER_DATASET = 100  # small dataset
NUM_EPOCHS = 1
```

### Stage 2: Full training
```python
MAX_SAMPLES_PER_DATASET = None  # full dataset
NUM_EPOCHS = 3
```

## Metrics to Monitor

Watch during training:
- **Average reward**: should increase over time
- **Reward distribution**: check the gap between positive and negative samples
- **Retrieval rate**: percentage of samples using `<RET>`
- **Average steps**: should decrease (better efficiency)
- **Timeout rate**: should decrease

## Success Criteria

Signs that training succeeded:
- Average reward > 0.5
- Timeout rate < 5%
- Positive samples: >70% use retrieval
- Negative samples: >70% answer directly
- Average steps: 1.0-1.5 (efficient)

## Citation

Implemented on top of the GRPO algorithm from the TRL library and the plan in `plan.md`.
|
ICL/RL/REWARD_DESIGN.md
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Reward Function Design

## Overview

The reward function is designed to train a retrieval-augmented VQA model using GRPO (Group Relative Policy Optimization). The model can either answer directly or retrieve similar images to help answer questions.

## Formula

```
R_total = R_outcome + R_rel + R_penalty
```

Where:
- **R_outcome**: Answer correctness reward
- **R_rel**: Retrieval relevance reward (gated by correctness)
- **R_penalty**: Step penalty to discourage unnecessary retrieval

---

## Component Details

### 1. R_outcome (Answer Correctness)

Measures whether the final answer is correct using the F1 score.

```python
f1 = compute_f1_score(prediction, ground_truth)
is_correct = (f1 > 0.5)

r_outcome = 1.0 if is_correct else -1.0
```

**Purpose**: Primary signal - the model must answer correctly.

---

### 2. R_penalty (Step Penalty)

Penalizes the number of retrieval steps to encourage efficiency.

```python
num_steps = trajectory.count("<RET>")

PENALTY_MAP = {
    0: 0.0,    # Direct answer, no penalty
    1: -0.1,   # Small penalty
    2: -0.25,  # Medium penalty
    3: -0.5,   # Large penalty
}
r_penalty = -2.0 if timed_out else PENALTY_MAP[num_steps]  # timeout: death penalty
```

**Purpose**: Encourage the model to use retrieval only when necessary.

**Timeout**: Occurs when the model exceeds `max_turns` (default: 3) without producing `<ANS>`.

---

### 3. R_rel (Retrieval Relevance)

Rewards retrieving relevant images, **only if the answer is correct** (gate mechanism).

```python
if is_correct and len(siglip_scores) > 0:
    # Scale SigLIP scores from [0.33, 0.97] to [0, 1]
    scaled_scores = [scale_siglip_score(s) for s in siglip_scores]
    r_rel = sum(scaled_scores) / len(scaled_scores)
else:
    r_rel = 0.0
```

**SigLIP Score Scaling**:
```python
def scale_siglip_score(raw_score, min_th=0.33, max_th=0.97):
    scaled = (raw_score - min_th) / (max_th - min_th)
    return min(max(scaled, 0.0), 1.0)
```

**Purpose**:
- Encourage retrieving images similar to the query image
- Only reward good retrieval if it leads to correct answers
- Thresholds based on data analysis: P05=0.33, P95=0.97 for same-dataset pairs

---

## Reward Examples

### Example 1: Direct Answer (Correct)
```
Trajectory: "<ANS> cat"
Ground Truth: "cat"
```
- r_outcome = +1.0 (correct)
- r_penalty = 0.0 (no retrieval)
- r_rel = 0.0 (no retrieval)
- **R_total = +1.0**

---

### Example 2: One Retrieval (Correct, High Similarity)
```
Trajectory: "<RET> a photo of a cat <ANS> cat"
Ground Truth: "cat"
SigLIP Score: 0.85
```
- r_outcome = +1.0 (correct)
- r_penalty = -0.1 (1 retrieval)
- r_rel = scale(0.85) = (0.85-0.33)/(0.97-0.33) ≈ 0.81
- **R_total = +1.71**

---

### Example 3: Two Retrievals (Correct, Mixed Similarity)
```
Trajectory: "<RET> animal <RET> cat photo <ANS> cat"
Ground Truth: "cat"
SigLIP Scores: [0.65, 0.90]
```
- r_outcome = +1.0 (correct)
- r_penalty = -0.25 (2 retrievals)
- r_rel = mean([scale(0.65), scale(0.90)]) = mean([0.50, 0.89]) ≈ 0.70
- **R_total = +1.45**

---

### Example 4: One Retrieval (Incorrect)
```
Trajectory: "<RET> a photo of a dog <ANS> dog"
Ground Truth: "cat"
SigLIP Score: 0.75
```
- r_outcome = -1.0 (incorrect)
- r_penalty = -0.1 (1 retrieval)
- r_rel = 0.0 (gated out because incorrect)
- **R_total = -1.1**

---

### Example 5: Direct Answer (Incorrect)
```
Trajectory: "<ANS> dog"
Ground Truth: "cat"
```
- r_outcome = -1.0 (incorrect)
- r_penalty = 0.0 (no retrieval)
- r_rel = 0.0 (no retrieval)
- **R_total = -1.0**

---

### Example 6: Timeout (Infinite Loop)
```
Trajectory: "<RET> query1 <RET> query2 <RET> query3"
Ground Truth: "cat"
Max Turns: 3
```
- r_outcome = -1.0 (no answer)
- r_penalty = -2.0 (timeout)
- r_rel = 0.0 (incorrect)
- **R_total = -3.0**

---

## Design Rationale

### 1. R_outcome Dominates
The ±1.0 range ensures correctness is the primary objective. Even with perfect retrieval (r_rel ≈ 0.8), an incorrect answer still gets a negative reward.

### 2. Efficiency Matters
The non-linear penalty (-0.1, -0.25, -0.5) discourages excessive retrieval. The model learns to retrieve only when beneficial.

### 3. Gate Mechanism
R_rel is only active when the answer is correct. This prevents the model from learning to retrieve irrelevant but similar images just to collect r_rel reward.

### 4. Timeout Prevention
The -2.0 death penalty strongly discourages infinite retrieval loops.

---

## Reward Range

**Theoretical Range**: [-3.0, +1.9]

- **Best case**: Correct answer with 1 highly relevant retrieval
  - r_outcome = +1.0
  - r_penalty = -0.1
  - r_rel ≈ +1.0 (perfect similarity)
  - R_total ≈ +1.9

- **Worst case**: Timeout
  - r_outcome = -1.0
  - r_penalty = -2.0
  - r_rel = 0.0
  - R_total = -3.0

**Typical Range**: [-1.5, +1.7]

---

## Implementation

See `reward_functions.py`:
- `compute_reward()`: Computes the reward for a single trajectory
- `batch_compute_rewards()`: Batch processing for GRPO
- `scale_siglip_score()`: Scales SigLIP similarity scores
- `compute_f1_score()`: Token-level F1 for answer correctness (see the sketch below)
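For reference, a minimal sketch of what token-level F1 means here (the shipped `compute_f1_score` may add normalization such as lowercasing or punctuation stripping on top of this):

```python
# Minimal sketch of token-level F1; the shipped compute_f1_score may differ
# in its normalization details.
def f1_score(prediction: str, ground_truth: str) -> float:
    pred_tokens = prediction.lower().split()
    gt_tokens = ground_truth.lower().split()
    common = 0
    remaining = list(gt_tokens)
    for tok in pred_tokens:
        if tok in remaining:
            remaining.remove(tok)  # count shared tokens with multiplicity
            common += 1
    if common == 0:
        return 0.0
    precision = common / len(pred_tokens)
    recall = common / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)

print(f1_score("the cat", "cat"))  # ≈ 0.667, above the 0.5 correctness threshold
```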

---

## Training Dynamics

### What the Model Learns

1. **Answer correctly first** (r_outcome = ±1.0 is dominant)
2. **Use retrieval strategically** (balance r_rel gain vs r_penalty cost)
3. **Retrieve relevant images** (maximize r_rel when retrieving)
4. **Avoid infinite loops** (the timeout penalty is severe)

### Expected Behavior

- **Easy questions**: Direct answer (R ≈ +1.0)
- **Hard questions needing context**: 1-2 retrievals (R ≈ +1.4 to +1.6)
- **Ambiguous questions**: May try retrieval but learn to minimize steps
- **Impossible questions**: Learn to answer directly rather than waste steps

---

## Monitoring

Key metrics to track during training:

1. **Mean Reward**: Should increase from negative to positive
2. **Reward Std**: Should be > 0.5 (diversity in completions)
3. **frac_reward_zero_std**: Should be < 0.5 (not all completions getting the same reward)
4. **Retrieval Rate**: Percentage of trajectories using <RET>
5. **Avg Steps**: Average number of retrievals per trajectory
6. **Timeout Rate**: Should be near 0%

---

## Hyperparameters

```python
# Reward function
F1_THRESHOLD = 0.5       # Threshold for correctness
MIN_SIGLIP = 0.33        # SigLIP score scaling min
MAX_SIGLIP = 0.97        # SigLIP score scaling max
MAX_TURNS = 3            # Maximum retrieval steps
TIMEOUT_PENALTY = -2.0   # Death penalty for timeout

# Penalty schedule
PENALTY_MAP = {
    0: 0.0,
    1: -0.1,
    2: -0.25,
    3: -0.5
}
```
ICL/RL/__pycache__/data_utils.cpython-311.pyc
ADDED
Binary file (7.81 kB)

ICL/RL/__pycache__/data_utils.cpython-313.pyc
ADDED
Binary file (6.51 kB)

ICL/RL/__pycache__/environment.cpython-311.pyc
ADDED
Binary file (11.4 kB)

ICL/RL/__pycache__/environment.cpython-313.pyc
ADDED
Binary file (9.56 kB)

ICL/RL/__pycache__/reward_functions.cpython-311.pyc
ADDED
Binary file (5.63 kB)

ICL/RL/__pycache__/reward_functions.cpython-313.pyc
ADDED
Binary file (4.81 kB)

ICL/RL/__pycache__/train_grpo.cpython-313.pyc
ADDED
Binary file (12.8 kB)
ICL/RL/build_rl_dataset.py
ADDED
@@ -0,0 +1,420 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Build RL training dataset with 50:50 Positive/Negative split.

Positive set (retrieval required):
- OK-VQA, A-OKVQA: knowledge-intensive QA
- ScienceQA: scientific/commonsense reasoning

Negative set (no/minimal retrieval):
- TextVQA, OCR-VQA: reading text from images
- VQA-v2: basic visual recognition
- CLEVR: pure logical reasoning

Usage:
    python build_rl_dataset.py \
        --dataset-root /workspace/xiaobin/M3IT \
        --output-dir /workspace/xiaobin/RL_data \
        --positive-samples 10000 \
        --negative-samples 10000
"""

import argparse
import base64
import hashlib
import json
import os
import random
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

try:
    from tqdm import tqdm
except ImportError:
    tqdm = None


_B64_RE = re.compile(r"[^A-Za-z0-9+/=]")


@dataclass(frozen=True)
class QueryItem:
    """Query item for RL training."""
    image_path: str
    question: str
    answer: str
    subdir: str
    uid: str
    category: str  # "positive" or "negative"


def _looks_like_base64(s: str) -> bool:
    if s.startswith("data:image"):
        return True
    if len(s) > 200 and all(c.isalnum() or c in "+/=\n\r" for c in s[:200]):
        return True
    return False


def _b64_to_image_path(b64: str, cache_dir: Path, prefix: str) -> Optional[str]:
    s = b64.strip()
    if s.startswith("data:image"):
        s = s.split(",", 1)[-1]
    s = _B64_RE.sub("", s)
    if not s:
        return None
    pad = len(s) % 4
    if pad:
        s += "=" * (4 - pad)
    try:
        data = base64.b64decode(s)
    except Exception:
        return None
    sha = hashlib.sha1(data).hexdigest()
    cache_dir.mkdir(parents=True, exist_ok=True)
    out = cache_dir / f"{prefix}_{sha}.jpg"
    if not out.exists():
        with out.open("wb") as f:
            f.write(data)
    return str(out)


def _resolve_image_path(v: str, dataset_root: Path) -> Optional[str]:
    p = Path(v)
    if p.is_absolute():
        return str(p)
    cand = dataset_root / v
    if cand.exists():
        return str(cand)
    return str(p)


def _extract_image_path(raw: Dict, dataset_root: Path, cache_dir: Path, prefix: str) -> Optional[str]:
    keys = [
        "image", "image_path", "image_file", "image_str", "base64",
        "img", "img_path", "img_str", "image_base64_str", "image_base64",
        "image_base64s", "images",
    ]
    for k in keys:
        if k not in raw:
            continue
        v = raw.get(k)
        if isinstance(v, list) and v:
            v = v[0]
        if not isinstance(v, str) or not v.strip():
            continue
        v = v.strip()
        if _looks_like_base64(v):
            return _b64_to_image_path(v, cache_dir, prefix)
        return _resolve_image_path(v, dataset_root)
    return None


def _extract_uid(raw: Dict, fallback: str) -> str:
    for k in ("id", "image_id", "img_id", "question_id"):
        v = raw.get(k)
        if isinstance(v, (str, int)):
            return str(v)
    meta = raw.get("meta") if isinstance(raw.get("meta"), dict) else {}
    for k in ("img_id", "id", "image_id"):
        v = meta.get(k)
        if isinstance(v, (str, int)):
            return str(v)
    return fallback


def _extract_question(raw: Dict) -> Optional[str]:
    for k in ("text", "question", "query", "prompt", "input"):
        v = raw.get(k)
        if isinstance(v, str) and v.strip():
            return v.strip()
    return None


def _extract_answer(raw: Dict) -> Optional[str]:
    if "answers" in raw:
        v = raw.get("answers")
        if isinstance(v, list):
            for a in v:
                if isinstance(a, str) and a.strip():
                    return a.strip()
    for k in ("answer", "output", "label", "target", "paraphrased_answer", "original_answer"):
        v = raw.get(k)
        if isinstance(v, str) and v.strip():
            return v.strip()
    return None


def discover_subdirs(dataset_root: Path, categories: List[str]) -> List[str]:
    """Discover all subdirs under dataset_root/data/{category}."""
    out = []
    for cat in categories:
        base = dataset_root / "data" / cat
        if not base.exists():
            continue
        for p in sorted(base.iterdir()):
            if p.is_dir():
                out.append(f"{cat}/{p.name}")
    return out


def find_split_file(subdir_dir: Path, split: str) -> Optional[Path]:
    """Find jsonl file for given split."""
    if not subdir_dir.exists():
        return None
    split = split.lower()
    files = sorted(p for p in subdir_dir.iterdir() if p.suffix == ".jsonl")
    if not files:
        return None

    exact = [p for p in files if split in p.name.lower()]
    if exact:
        return exact[0]

    if split in ("train", "training"):
        for key in ("train", "training"):
            cand = [p for p in files if key in p.name.lower()]
            if cand:
                return cand[0]

    return files[0]


def load_instructions(dataset_root: Path, subdir: str) -> List[str]:
    """Load instructions for a subdir."""
    base = dataset_root / "data" / subdir
    for name in ("instructions.json", "instruction.json"):
        path = base / name
        if not path.exists():
            continue
        with path.open("r", encoding="utf-8") as f:
            try:
                data = json.load(f)
            except Exception:
                return []
        if isinstance(data, list):
            return [str(x).strip() for x in data if str(x).strip()]
        if isinstance(data, dict):
            for key in ("instructions", "instruction", "prompts"):
                v = data.get(key)
                if isinstance(v, list):
                    return [str(x).strip() for x in v if str(x).strip()]
                if isinstance(v, str) and v.strip():
                    return [v.strip()]
            return []
    return []


def build_query_pool_for_subdir(
    dataset_root: Path,
    subdir: str,
    split: str,
    cache_dir: Path,
    scan_limit: int,
    category: str,
    show_progress: bool,
) -> List[QueryItem]:
    """Build query pool for a single subdir."""
    subdir_dir = dataset_root / "data" / subdir
    jsonl_path = find_split_file(subdir_dir, split)
    if jsonl_path is None:
        return []

    out: List[QueryItem] = []
    prefix = subdir.replace("/", "_")

    with jsonl_path.open("r", encoding="utf-8") as f:
        lines = [line for line in f if line.strip()]

    if show_progress and tqdm is not None:
        lines = tqdm(lines, desc=f"load:{subdir}", unit="rec", mininterval=1.0)

    for idx, line in enumerate(lines):
        if scan_limit > 0 and idx >= scan_limit:
            break

        try:
            raw = json.loads(line)
        except Exception:
            continue

        question = _extract_question(raw)
        answer = _extract_answer(raw)
        if not question or not answer:
            continue

        img_path = _extract_image_path(raw, dataset_root, cache_dir, prefix)
        if not img_path:
            continue

        uid = _extract_uid(raw, f"{subdir}:{idx:08d}")

        out.append(QueryItem(
            image_path=img_path,
            question=question,
            answer=answer,
            subdir=subdir,
            uid=uid,
            category=category,
        ))

    return out


def main() -> int:
    ap = argparse.ArgumentParser(description="Build RL training dataset with 50:50 Positive/Negative split")
    ap.add_argument("--dataset-root", default="/workspace/xiaobin/M3IT")
    ap.add_argument("--output-dir", default="/workspace/xiaobin/RL_data")
    ap.add_argument("--split", default="train")
    ap.add_argument("--positive-samples", type=int, default=10000)
    ap.add_argument("--negative-samples", type=int, default=10000)
    ap.add_argument("--scan-limit", type=int, default=100000, help="Max records to scan per subdir")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--no-progress", action="store_true")
    args = ap.parse_args()

    rng = random.Random(args.seed)
    dataset_root = Path(args.dataset_root)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    cache_dir = output_dir / "_image_cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    show_progress = not args.no_progress

    # Define positive and negative subdirs
    positive_subdirs_map = {
        "vqa": ["okvqa", "a-okvqa", "science-qa"],
        "reasoning": [],  # Add if available
    }

    negative_subdirs_map = {
        "vqa": ["text-vqa", "ocr-vqa", "vqa-v2", "clevr"],
    }

    print("[INFO] Discovering subdirs...")

    # Collect positive subdirs
    positive_subdirs = []
    for cat, names in positive_subdirs_map.items():
        base = dataset_root / "data" / cat
        if not base.exists():
            continue
        for name in names:
            subdir_path = base / name
            if subdir_path.exists() and subdir_path.is_dir():
                positive_subdirs.append(f"{cat}/{name}")

    # Collect negative subdirs
    negative_subdirs = []
    for cat, names in negative_subdirs_map.items():
        base = dataset_root / "data" / cat
        if not base.exists():
            continue
        for name in names:
            subdir_path = base / name
            if subdir_path.exists() and subdir_path.is_dir():
                negative_subdirs.append(f"{cat}/{name}")

    print(f"[INFO] Found {len(positive_subdirs)} positive subdirs: {positive_subdirs}")
    print(f"[INFO] Found {len(negative_subdirs)} negative subdirs: {negative_subdirs}")

    if not positive_subdirs or not negative_subdirs:
        print("[ERROR] Need both positive and negative subdirs")
        return 1

    # Build positive pool
    print("\n[INFO] Building positive pool...")
    positive_pool: List[QueryItem] = []
    for sd in positive_subdirs:
        queries = build_query_pool_for_subdir(
            dataset_root, sd, args.split, cache_dir,
            args.scan_limit, "positive", show_progress
        )
        positive_pool.extend(queries)
        print(f"  [{sd}] Loaded {len(queries)} queries")

    print(f"[INFO] Total positive pool: {len(positive_pool)}")

    # Build negative pool
    print("\n[INFO] Building negative pool...")
    negative_pool: List[QueryItem] = []
    for sd in negative_subdirs:
        queries = build_query_pool_for_subdir(
            dataset_root, sd, args.split, cache_dir,
            args.scan_limit, "negative", show_progress
        )
        negative_pool.extend(queries)
        print(f"  [{sd}] Loaded {len(queries)} queries")

    print(f"[INFO] Total negative pool: {len(negative_pool)}")

    # Sample
    if len(positive_pool) < args.positive_samples:
        print(f"[WARN] Only {len(positive_pool)} positive samples available, using all")
        positive_samples = positive_pool
    else:
        positive_samples = rng.sample(positive_pool, args.positive_samples)

    if len(negative_pool) < args.negative_samples:
        print(f"[WARN] Only {len(negative_pool)} negative samples available, using all")
        negative_samples = negative_pool
    else:
        negative_samples = rng.sample(negative_pool, args.negative_samples)

    # Combine and shuffle
    all_samples = positive_samples + negative_samples
    rng.shuffle(all_samples)

    print(f"\n[INFO] Final dataset: {len(positive_samples)} positive + {len(negative_samples)} negative = {len(all_samples)} total")

    # Load instructions for each subdir
    inst_map: Dict[str, List[str]] = {}
    all_subdirs = set(q.subdir for q in all_samples)
    for sd in all_subdirs:
        inst_map[sd] = load_instructions(dataset_root, sd)

    # Write to jsonl
    output_file = output_dir / "rl_train.jsonl"
    with output_file.open("w", encoding="utf-8") as f:
        for q in all_samples:
            instructions = inst_map.get(q.subdir, [])
            instruction = rng.choice(instructions) if instructions else "Please answer the question based on the image."

            rec = {
                "id": q.uid,
                "category": q.category,
                "subdir": q.subdir,
                "image": q.image_path,
                "question": q.question,
                "answer": q.answer,
                "instruction": instruction,
            }
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")

    print(f"\n[INFO] Saved to {output_file}")

    # Write statistics
    stats_file = output_dir / "dataset_stats.json"
    stats = {
        "total_samples": len(all_samples),
        "positive_samples": len(positive_samples),
        "negative_samples": len(negative_samples),
        "positive_subdirs": positive_subdirs,
        "negative_subdirs": negative_subdirs,
        "positive_distribution": {sd: sum(1 for q in positive_samples if q.subdir == sd) for sd in positive_subdirs},
        "negative_distribution": {sd: sum(1 for q in negative_samples if q.subdir == sd) for sd in negative_subdirs},
    }
    with stats_file.open("w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2, ensure_ascii=False)

    print(f"[INFO] Saved statistics to {stats_file}")
    print("\n[DONE] Dataset construction complete!")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
ICL/RL/config.yaml
ADDED
@@ -0,0 +1,82 @@
# GRPO Training Configuration

# Model paths
model:
  vlm_path: "/workspace/xiaobin/SFT_model/hf_qwen3vl_siglip_vqa_iter_0000881"
  siglip_path: "/workspace/siglip2-so400m-patch16-naflex"
  output_dir: "/workspace/xiaobin/RL/output"

# Training hyperparameters
training:
  learning_rate: 1.0e-5
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 4
  num_train_epochs: 3
  warmup_steps: 100
  max_grad_norm: 1.0
  bf16: true

# GRPO specific
grpo:
  num_generations: 8  # Group size for relative optimization
  max_prompt_length: 512
  max_completion_length: 1024
  temperature: 0.7
  do_sample: true

# Environment
environment:
  max_turns: 3  # Maximum retrieval turns before timeout
  self_exclusion: true  # Exclude query image from retrieval

# Reward function
reward:
  # R_outcome
  correct_reward: 1.0
  wrong_reward: -1.0

  # R_penalty (step-based)
  penalty_0_shot: 0.0
  penalty_1_shot: -0.1
  penalty_2_shot: -0.25
  penalty_3_shot: -0.5
  penalty_timeout: -2.0

  # R_rel (SigLIP similarity scaling)
  siglip_min_threshold: 0.33  # P05 of same-subdir pairs
  siglip_max_threshold: 0.97  # P95 of same-subdir pairs

# Data
data:
  # Positive datasets (needs retrieval)
  positive_datasets:
    - "okvqa"
    - "aokvqa"
    - "science_qa"

  # Negative datasets (direct answer)
  negative_datasets:
    - "text_vqa"
    - "ocr_vqa"
    - "vqa_v2"
    - "clevr"

  # Sampling
  max_samples_per_dataset: null  # null for full dataset, or set a number for debugging
  shuffle: true
  seed: 42

# Logging
logging:
  use_wandb: true
  wandb_project: "grpo-retrieval-vqa"
  logging_steps: 10
  save_steps: 500
  save_total_limit: 3

# Hardware
hardware:
  device: "cuda"
  num_gpus: 8
  tensor_parallel: 8
  data_parallel: 1
ICL/RL/data_utils.py
ADDED
@@ -0,0 +1,192 @@
"""
Data utilities for loading and preprocessing M3IT dataset.
Splits into 50% positive (needs retrieval) and 50% negative (direct answer).
"""
from datasets import load_dataset, concatenate_datasets
from typing import Dict, List, Tuple
import random


# Dataset categorization based on plan
POSITIVE_DATASETS = [
    "okvqa",       # Knowledge-intensive VQA
    "aokvqa",      # A-OKVQA
    "science_qa",  # Science reasoning
    # "viquae",    # Entity knowledge (if available)
]

NEGATIVE_DATASETS = [
    "text_vqa",  # Read text from image
    "ocr_vqa",   # OCR-based VQA
    "vqa_v2",    # Basic visual recognition
    "clevr",     # Pure logic reasoning
]


def load_m3it_splits(
    positive_datasets: List[str] = POSITIVE_DATASETS,
    negative_datasets: List[str] = NEGATIVE_DATASETS,
    split: str = "train",
    max_samples_per_dataset: int = None
) -> Tuple[List[Dict], List[Dict]]:
    """
    Load M3IT dataset and split into positive/negative sets.

    Args:
        positive_datasets: List of dataset names that need retrieval
        negative_datasets: List of dataset names that don't need retrieval
        split: Dataset split (train/validation/test)
        max_samples_per_dataset: Maximum samples per dataset (for debugging)

    Returns:
        (positive_samples, negative_samples)
    """
    positive_samples = []
    negative_samples = []

    print(f"Loading M3IT datasets ({split} split)...")

    # Load positive datasets
    for dataset_name in positive_datasets:
        try:
            ds = load_dataset("MMInstruction/M3IT", dataset_name, split=split)
            if max_samples_per_dataset:
                ds = ds.select(range(min(len(ds), max_samples_per_dataset)))

            samples = [
                {
                    "image": sample["image"],
                    "question": sample["instruction"],
                    "answer": sample["outputs"],
                    "dataset": dataset_name,
                    "needs_retrieval": True,
                    "id": f"{dataset_name}_{i}"
                }
                for i, sample in enumerate(ds)
            ]
            positive_samples.extend(samples)
            print(f"  ✓ {dataset_name}: {len(samples)} samples (positive)")
        except Exception as e:
            print(f"  ✗ {dataset_name}: Failed to load - {e}")

    # Load negative datasets
    for dataset_name in negative_datasets:
        try:
            ds = load_dataset("MMInstruction/M3IT", dataset_name, split=split)
            if max_samples_per_dataset:
                ds = ds.select(range(min(len(ds), max_samples_per_dataset)))

            samples = [
                {
                    "image": sample["image"],
                    "question": sample["instruction"],
                    "answer": sample["outputs"],
                    "dataset": dataset_name,
                    "needs_retrieval": False,
                    "id": f"{dataset_name}_{i}"
                }
                for i, sample in enumerate(ds)
            ]
            negative_samples.extend(samples)
            print(f"  ✓ {dataset_name}: {len(samples)} samples (negative)")
        except Exception as e:
            print(f"  ✗ {dataset_name}: Failed to load - {e}")

    print(f"\nTotal: {len(positive_samples)} positive, {len(negative_samples)} negative")
    return positive_samples, negative_samples


def create_balanced_dataset(
    positive_samples: List[Dict],
    negative_samples: List[Dict],
    shuffle: bool = True,
    seed: int = 42
) -> List[Dict]:
    """
    Create balanced 50:50 dataset by sampling.

    Args:
        positive_samples: Positive samples (needs retrieval)
        negative_samples: Negative samples (direct answer)
        shuffle: Whether to shuffle the combined dataset
        seed: Random seed

    Returns:
        Balanced dataset
    """
    # Balance to 50:50
    min_size = min(len(positive_samples), len(negative_samples))

    random.seed(seed)
    balanced_positive = random.sample(positive_samples, min_size)
    balanced_negative = random.sample(negative_samples, min_size)

    # Combine
    balanced_dataset = balanced_positive + balanced_negative

    if shuffle:
        random.shuffle(balanced_dataset)

    print(f"Created balanced dataset: {len(balanced_dataset)} samples (50% pos, 50% neg)")
    return balanced_dataset


def build_candidate_pools(
    positive_samples: List[Dict],
    negative_samples: List[Dict]
) -> Dict[str, List[Dict]]:
    """
    Build candidate pools for retrieval, organized by dataset.
    Each pool contains all images from that dataset for sub-dataset retrieval.

    Args:
        positive_samples: Positive samples
        negative_samples: Negative samples

    Returns:
        Dictionary mapping dataset name to list of candidates
    """
    all_samples = positive_samples + negative_samples
    pools = {}

    for sample in all_samples:
        dataset_name = sample["dataset"]
        if dataset_name not in pools:
            pools[dataset_name] = []
+
pools[dataset_name].append({
|
| 158 |
+
"image": sample["image"],
|
| 159 |
+
"caption": sample["answer"], # Use answer as caption
|
| 160 |
+
"id": sample["id"]
|
| 161 |
+
})
|
| 162 |
+
|
| 163 |
+
print(f"\nBuilt candidate pools for {len(pools)} datasets:")
|
| 164 |
+
for name, pool in pools.items():
|
| 165 |
+
print(f" {name}: {len(pool)} candidates")
|
| 166 |
+
|
| 167 |
+
return pools
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def format_prompt(sample: Dict, include_system_prompt: bool = True) -> str:
|
| 171 |
+
"""
|
| 172 |
+
Format a sample into a prompt for the model.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
sample: Sample dictionary
|
| 176 |
+
include_system_prompt: Whether to include system instructions
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
Formatted prompt string
|
| 180 |
+
"""
|
| 181 |
+
system_prompt = ""
|
| 182 |
+
if include_system_prompt:
|
| 183 |
+
system_prompt = (
|
| 184 |
+
"You are a visual question answering assistant. "
|
| 185 |
+
"You can either answer directly using <ANS> or retrieve similar images using <RET>.\n"
|
| 186 |
+
"- Use <RET> followed by a description when you need to see similar examples.\n"
|
| 187 |
+
"- Use <ANS> followed by your answer when you're ready to answer.\n\n"
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
user_prompt = f"Question: {sample['question']}\n"
|
| 191 |
+
|
| 192 |
+
return system_prompt + user_prompt
|
ICL/RL/environment.py
ADDED
@@ -0,0 +1,240 @@
"""
Environment for GRPO rollout with SigLIP-based retrieval.
Implements the while-loop interaction: <RET> -> retrieve top-1 -> inject -> continue
"""
import torch
import torch.nn.functional as F
from typing import List, Dict, Any, Tuple, Optional
from PIL import Image
import numpy as np


class RetrievalEnvironment:
    """
    Environment that handles the retrieval loop during GRPO rollout.
    """

    def __init__(
        self,
        siglip_model,
        siglip_processor,
        candidate_pools: Dict[str, List[Dict]],
        max_turns: int = 3,
        device: str = "cuda"
    ):
        """
        Args:
            siglip_model: SigLIP vision-language model
            siglip_processor: SigLIP processor
            candidate_pools: Dict mapping dataset name to list of candidate images
                Each candidate: {"image": PIL.Image, "caption": str, "id": str}
            max_turns: Maximum retrieval turns before timeout
            device: Device for computation
        """
        self.siglip_model = siglip_model
        self.siglip_processor = siglip_processor
        self.candidate_pools = candidate_pools
        self.max_turns = max_turns
        self.device = device

        # Precompute image embeddings for each pool
        self.pool_embeddings = {}
        self._precompute_embeddings()

    def _precompute_embeddings(self):
        """Precompute SigLIP embeddings for all candidate images."""
        print("Precomputing SigLIP embeddings for candidate pools...")
        self.siglip_model.eval()

        # Get the device of the first parameter of siglip_model
        model_device = next(self.siglip_model.parameters()).device

        for dataset_name, candidates in self.candidate_pools.items():
            images = [c["image"] for c in candidates]
            embeddings = []

            # Process in batches
            batch_size = 32
            for i in range(0, len(images), batch_size):
                batch_images = images[i:i + batch_size]
                inputs = self.siglip_processor(
                    images=batch_images,
                    return_tensors="pt",
                    padding=True
                )

                # Move inputs to the same device as siglip_model
                inputs = {k: v.to(model_device) for k, v in inputs.items()}

                with torch.no_grad():
                    image_embeds = self.siglip_model.get_image_features(**inputs)
                    # Extract tensor from output object if needed
                    if hasattr(image_embeds, 'pooler_output'):
                        image_embeds = image_embeds.pooler_output
                    elif not isinstance(image_embeds, torch.Tensor):
                        image_embeds = image_embeds[0]
                    image_embeds = F.normalize(image_embeds, p=2, dim=-1)
                    embeddings.append(image_embeds.cpu())

            self.pool_embeddings[dataset_name] = torch.cat(embeddings, dim=0)
            print(f"  {dataset_name}: {len(candidates)} images")

    def retrieve_top1(
        self,
        query_text: str,
        dataset_name: str,
        exclude_id: str
    ) -> Tuple[Dict, float]:
        """
        Retrieve top-1 most similar image based on query text.

        Args:
            query_text: Description text from model's <RET> output
            dataset_name: Which candidate pool to search
            exclude_id: Image ID to exclude (the query image itself)

        Returns:
            (retrieved_candidate, similarity_score)
        """
        # Encode query text
        text_inputs = self.siglip_processor(
            text=[query_text],
            return_tensors="pt",
            padding=True
        )

        # Move inputs to the same device as siglip_model
        # Get the device of the first parameter of siglip_model
        model_device = next(self.siglip_model.parameters()).device
        text_inputs = {k: v.to(model_device) for k, v in text_inputs.items()}

        with torch.no_grad():
            text_embeds = self.siglip_model.get_text_features(**text_inputs)
            # Extract tensor from output object if needed
            if hasattr(text_embeds, 'pooler_output'):
                text_embeds = text_embeds.pooler_output
            elif not isinstance(text_embeds, torch.Tensor):
                text_embeds = text_embeds[0]
            text_embeds = F.normalize(text_embeds, p=2, dim=-1)

        # Compute similarities - move image_embeds to same device as text_embeds
        image_embeds = self.pool_embeddings[dataset_name].to(text_embeds.device)
        similarities = (text_embeds @ image_embeds.T).squeeze(0)

        # Get candidates
        candidates = self.candidate_pools[dataset_name]

        # Exclude query image
        valid_indices = [i for i, c in enumerate(candidates) if c["id"] != exclude_id]
        valid_similarities = similarities[valid_indices]

        # Get top-1
        top_idx = valid_similarities.argmax().item()
        actual_idx = valid_indices[top_idx]
        top_score = valid_similarities[top_idx].item()

        return candidates[actual_idx], top_score

    def compute_image_similarity(
        self,
        query_image: Image.Image,
        retrieved_image: Image.Image
    ) -> float:
        """
        Compute SigLIP image-to-image similarity for R_rel reward.
        """
        inputs = self.siglip_processor(
            images=[query_image, retrieved_image],
            return_tensors="pt",
            padding=True
        )

        # Move inputs to the same device as siglip_model
        model_device = next(self.siglip_model.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        with torch.no_grad():
            image_embeds = self.siglip_model.get_image_features(**inputs)
            # Extract tensor from output object if needed
            if hasattr(image_embeds, 'pooler_output'):
                image_embeds = image_embeds.pooler_output
            elif not isinstance(image_embeds, torch.Tensor):
                image_embeds = image_embeds[0]
            image_embeds = F.normalize(image_embeds, p=2, dim=-1)

        similarity = (image_embeds[0] @ image_embeds[1]).item()
        return similarity

    def rollout(
        self,
        model,
        tokenizer,
        initial_prompt: str,
        query_image: Image.Image,
        query_id: str,
        dataset_name: str,
        generation_kwargs: Dict[str, Any]
    ) -> Tuple[str, List[float], bool]:
        """
        Execute one rollout trajectory with retrieval loop.

        Args:
            model: VLM model
            tokenizer: Tokenizer
            initial_prompt: Initial prompt with image and question
            query_image: Query image
            query_id: Query image ID for self-exclusion
            dataset_name: Dataset name for candidate pool
            generation_kwargs: Generation parameters

        Returns:
            (full_trajectory, siglip_scores, timeout)
        """
        conversation_history = initial_prompt
        siglip_scores = []
        timeout = False

        for turn in range(self.max_turns):
            # Generate model output
            output = model.generate(
                conversation_history,
                **generation_kwargs
            )

            # Check if model outputs <ANS>
            if "<ANS>" in output:
                conversation_history += output
                break

            # Check if model outputs <RET>
            if "<RET>" in output:
                # Extract description text after <RET>
                description = output.split("<RET>")[-1].strip()

                # Retrieve top-1 similar image
                retrieved_candidate, _ = self.retrieve_top1(
                    query_text=description,
                    dataset_name=dataset_name,
                    exclude_id=query_id
                )

                # Compute image-to-image similarity for reward
                img_similarity = self.compute_image_similarity(
                    query_image,
                    retrieved_candidate["image"]
                )
                siglip_scores.append(img_similarity)

                # Inject retrieved image into conversation
                system_response = f"\nSystem: Retrieved Image - {retrieved_candidate['caption']}\n"
                conversation_history += output + system_response
            else:
                # Model output neither <RET> nor <ANS>, treat as direct answer
                conversation_history += output
                break

        # Check timeout
        if turn == self.max_turns - 1 and "<ANS>" not in conversation_history:
            timeout = True

        return conversation_history, siglip_scores, timeout
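
A wiring sketch for the class above (illustrative only: the checkpoint path mirrors the one used in inference_example.py, `pools` is assumed to come from `data_utils.build_candidate_pools`, and the query text and ID are made up for the example):

```python
# Sketch: build a RetrievalEnvironment from a local SigLIP checkpoint.
from transformers import AutoModel, AutoProcessor
from environment import RetrievalEnvironment

SIGLIP = "/workspace/siglip2-so400m-patch16-naflex"
siglip_model = AutoModel.from_pretrained(SIGLIP).to("cuda")
siglip_processor = AutoProcessor.from_pretrained(SIGLIP)

env = RetrievalEnvironment(siglip_model, siglip_processor,
                           candidate_pools=pools, max_turns=3)
cand, score = env.retrieve_top1("a red double-decker bus",
                                dataset_name="okvqa",
                                exclude_id="okvqa_0")
```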
ICL/RL/inference_example.py
ADDED
@@ -0,0 +1,157 @@
"""
Example inference script for the trained GRPO model.
Shows how to use the model for retrieval-augmented VQA.
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
from PIL import Image
from environment import RetrievalEnvironment


def load_trained_model(model_path: str, device: str = "cuda"):
    """Load the trained GRPO model."""
    print(f"Loading model from {model_path}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer


def inference_with_retrieval(
    model,
    tokenizer,
    environment: RetrievalEnvironment,
    image_path: str,
    question: str,
    dataset_name: str,
    image_id: str,
    max_new_tokens: int = 512
):
    """
    Run inference with retrieval capability.

    Args:
        model: Trained VLM model
        tokenizer: Tokenizer
        environment: Retrieval environment
        image_path: Path to query image
        question: Question to answer
        dataset_name: Dataset name for candidate pool
        image_id: Image ID for self-exclusion
        max_new_tokens: Maximum tokens to generate

    Returns:
        Final answer and retrieval history
    """
    # Load image
    image = Image.open(image_path).convert("RGB")

    # Format prompt
    prompt = (
        "You are a visual question answering assistant. "
        "You can either answer directly using <ANS> or retrieve similar images using <RET>.\n"
        f"Question: {question}\n"
    )

    # Execute rollout
    trajectory, siglip_scores, timeout = environment.rollout(
        model=model,
        tokenizer=tokenizer,
        initial_prompt=prompt,
        query_image=image,
        query_id=image_id,
        dataset_name=dataset_name,
        generation_kwargs={
            "max_new_tokens": max_new_tokens,
            "temperature": 0.7,
            "do_sample": True,
        }
    )

    # Extract final answer
    if "<ANS>" in trajectory:
        answer = trajectory.split("<ANS>")[-1].strip()
    else:
        answer = "No answer generated (timeout)"

    # Count retrieval steps
    num_retrievals = trajectory.count("<RET>")

    return {
        "answer": answer,
        "num_retrievals": num_retrievals,
        "siglip_scores": siglip_scores,
        "timeout": timeout,
        "full_trajectory": trajectory
    }


def main():
    """Example usage."""
    # Configuration
    MODEL_PATH = "/workspace/xiaobin/RL/output/final_model"
    SIGLIP_MODEL = "/workspace/siglip2-so400m-patch16-naflex"

    # Example query
    IMAGE_PATH = "path/to/your/image.jpg"
    QUESTION = "What is the capital of the country shown in this image?"
    DATASET_NAME = "okvqa"  # Which candidate pool to use
    IMAGE_ID = "example_001"

    print("="*80)
    print("GRPO MODEL INFERENCE EXAMPLE")
    print("="*80)
    print()

    # Load model
    model, tokenizer = load_trained_model(MODEL_PATH)

    # Load SigLIP for retrieval
    print(f"Loading SigLIP from {SIGLIP_MODEL}...")
    from transformers import AutoModel
    siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL).to("cuda")
    siglip_processor = AutoProcessor.from_pretrained(SIGLIP_MODEL)

    # Initialize environment (you need to provide candidate pools)
    # For this example, we'll skip the full environment setup
    print("\nNote: Full environment setup requires candidate pools.")
    print("See train_grpo.py for complete example.")
    print()

    # Example without environment (direct generation)
    print("Running direct generation (without retrieval)...")
    print(f"Question: {QUESTION}")

    # Format prompt
    prompt = (
        "You are a visual question answering assistant. "
        "Answer the following question about the image.\n"
        f"Question: {QUESTION}\n"
        "Answer:"
    )

    # Generate
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        temperature=0.7,
        do_sample=True
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print(f"Answer: {answer}")
    print()

    print("="*80)
    print("For full retrieval-augmented inference, use the RetrievalEnvironment")
    print("with properly initialized candidate pools.")
    print("="*80)


if __name__ == "__main__":
    main()
ICL/RL/key_metrics_20260220_152053.log
ADDED
@@ -0,0 +1,10 @@
[tqdm progress stream: steps 1-100 of 1875, ~27-44 s/it]
{'loss': '-0.05256', 'grad_norm': '5.906', 'learning_rate': '9e-07', 'num_tokens': '3.419e+05', 'completions/mean_length': '28.2', 'completions/min_length': '7.1', 'completions/max_length': '203.4', 'completions/clipped_ratio': '0.01094', 'completions/mean_terminated_length': '25.66', 'completions/min_terminated_length': '7.1', 'completions/max_terminated_length': '96.7', 'rewards/reward_function/mean': '-0.6679', 'rewards/reward_function/std': '0.5618', 'reward': '-0.6679', 'reward_std': '0.5618', 'frac_reward_zero_std': '0.01875', 'kl': '0.000573', 'entropy': '1.521', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '38.24', 'epoch': '0.016'}
{'loss': '0.006149', 'grad_norm': '7.188', 'learning_rate': '1.9e-06', 'num_tokens': '6.994e+05', 'completions/mean_length': '25.21', 'completions/min_length': '7', 'completions/max_length': '97.2', 'completions/clipped_ratio': '0', 'completions/mean_terminated_length': '25.21', 'completions/min_terminated_length': '7', 'completions/max_terminated_length': '97.2', 'rewards/reward_function/mean': '-0.6674', 'rewards/reward_function/std': '0.5884', 'reward': '-0.6674', 'reward_std': '0.5884', 'frac_reward_zero_std': '0.025', 'kl': '0.0006873', 'entropy': '0.5141', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '29.57', 'epoch': '0.032'}
{'loss': '0.00594', 'grad_norm': '5.969', 'learning_rate': '2.9e-06', 'num_tokens': '1.024e+06', 'completions/mean_length': '29.12', 'completions/min_length': '7.5', 'completions/max_length': '143', 'completions/clipped_ratio': '0.004687', 'completions/mean_terminated_length': '28.04', 'completions/min_terminated_length': '7.5', 'completions/max_terminated_length': '100.2', 'rewards/reward_function/mean': '-0.6755', 'rewards/reward_function/std': '0.5826', 'reward': '-0.6755', 'reward_std': '0.5826', 'frac_reward_zero_std': '0.03125', 'kl': '0.002043', 'entropy': '0.8444', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '32.29', 'epoch': '0.048'}
{'loss': '-0.02556', 'grad_norm': '7.344', 'learning_rate': '3.9e-06', 'num_tokens': '1.341e+06', 'completions/mean_length': '32.32', 'completions/min_length': '7.6', 'completions/max_length': '191', 'completions/clipped_ratio': '0.007812', 'completions/mean_terminated_length': '30.55', 'completions/min_terminated_length': '7.6', 'completions/max_terminated_length': '129.5', 'rewards/reward_function/mean': '-0.6495', 'rewards/reward_function/std': '0.6234', 'reward': '-0.6495', 'reward_std': '0.6234', 'frac_reward_zero_std': '0.05625', 'kl': '0.006311', 'entropy': '1.191', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '36.22', 'epoch': '0.064'}
{'loss': '-0.01995', 'grad_norm': '5.094', 'learning_rate': '4.9e-06', 'num_tokens': '1.695e+06', 'completions/mean_length': '35.92', 'completions/min_length': '7.3', 'completions/max_length': '226.1', 'completions/clipped_ratio': '0.02031', 'completions/mean_terminated_length': '31.4', 'completions/min_terminated_length': '7.3', 'completions/max_terminated_length': '103.7', 'rewards/reward_function/mean': '-0.5803', 'rewards/reward_function/std': '0.6009', 'reward': '-0.5803', 'reward_std': '0.6009', 'frac_reward_zero_std': '0.0875', 'kl': '0.02276', 'entropy': '1.659', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '40.29', 'epoch': '0.08'}
{'loss': '0.08347', 'grad_norm': '7.656', 'learning_rate': '5e-06', 'num_tokens': '2.039e+06', 'completions/mean_length': '39.98', 'completions/min_length': '9.5', 'completions/max_length': '208.9', 'completions/clipped_ratio': '0.02969', 'completions/mean_terminated_length': '33.43', 'completions/min_terminated_length': '9.5', 'completions/max_terminated_length': '122.4', 'rewards/reward_function/mean': '-0.3908', 'rewards/reward_function/std': '0.6835', 'reward': '-0.3908', 'reward_std': '0.6835', 'frac_reward_zero_std': '0.2062', 'kl': '0.07372', 'entropy': '1.91', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '35.99', 'epoch': '0.096'}
{'loss': '0.1118', 'grad_norm': '4.375', 'learning_rate': '5e-06', 'num_tokens': '2.398e+06', 'completions/mean_length': '37.39', 'completions/min_length': '8.1', 'completions/max_length': '222.2', 'completions/clipped_ratio': '0.02031', 'completions/mean_terminated_length': '32.87', 'completions/min_terminated_length': '8.1', 'completions/max_terminated_length': '113.3', 'rewards/reward_function/mean': '-0.3459', 'rewards/reward_function/std': '0.7209', 'reward': '-0.3459', 'reward_std': '0.7209', 'frac_reward_zero_std': '0.1812', 'kl': '0.05823', 'entropy': '1.797', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '38.46', 'epoch': '0.112'}
{'loss': '0.08823', 'grad_norm': '3.328', 'learning_rate': '5e-06', 'num_tokens': '2.736e+06', 'completions/mean_length': '33.79', 'completions/min_length': '9.6', 'completions/max_length': '181.3', 'completions/clipped_ratio': '0.01094', 'completions/mean_terminated_length': '31.35', 'completions/min_terminated_length': '9.6', 'completions/max_terminated_length': '113.7', 'rewards/reward_function/mean': '-0.2964', 'rewards/reward_function/std': '0.7092', 'reward': '-0.2964', 'reward_std': '0.7092', 'frac_reward_zero_std': '0.2562', 'kl': '0.07571', 'entropy': '1.331', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '33.64', 'epoch': '0.128'}
{'loss': '0.0669', 'grad_norm': '3.906', 'learning_rate': '5e-06', 'num_tokens': '3.073e+06', 'completions/mean_length': '30.91', 'completions/min_length': '9.5', 'completions/max_length': '156.7', 'completions/clipped_ratio': '0.009375', 'completions/mean_terminated_length': '28.78', 'completions/min_terminated_length': '9.5', 'completions/max_terminated_length': '95.5', 'rewards/reward_function/mean': '-0.2708', 'rewards/reward_function/std': '0.7212', 'reward': '-0.2708', 'reward_std': '0.7212', 'frac_reward_zero_std': '0.3125', 'kl': '0.06573', 'entropy': '1.085', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '32.09', 'epoch': '0.144'}
{'loss': '0.04796', 'grad_norm': '8.562', 'learning_rate': '5e-06', 'num_tokens': '3.432e+06', 'completions/mean_length': '27.49', 'completions/min_length': '9.2', 'completions/max_length': '127', 'completions/clipped_ratio': '0.003125', 'completions/mean_terminated_length': '26.78', 'completions/min_terminated_length': '9.2', 'completions/max_terminated_length': '90.9', 'rewards/reward_function/mean': '-0.2074', 'rewards/reward_function/std': '0.7264', 'reward': '-0.2074', 'reward_std': '0.7264', 'frac_reward_zero_std': '0.325', 'kl': '0.06459', 'entropy': '0.6255', 'clip_ratio/low_mean': '0', 'clip_ratio/low_min': '0', 'clip_ratio/high_mean': '0', 'clip_ratio/high_max': '0', 'clip_ratio/region_mean': '0', 'step_time': '31.41', 'epoch': '0.16'}
ICL/RL/key_metrics_20260224_094601.log
ADDED
File without changes
ICL/RL/key_metrics_20260224_133510.log
ADDED
The diff for this file is too large to render. See raw diff.
ICL/RL/plan.md
ADDED
@@ -0,0 +1,167 @@
# GRPO Training Plan: Model-Driven Iterative Retrieval with M3IT and SigLIP

## Part 1: Data Strategy (The Data Curriculum)

To train the model to "know when to hold and when to fold", we must build a **1:1 adversarial data pool**.

### 1. Data Source: The M3IT Dataset
Split the data into two classes, then mix and shuffle them for RL training.

* **Positive Set (retrieval required): 50%**
    * **Task definition:** the question cannot be answered from the pixels alone; it needs external knowledge or similar cases.
    * **Selected subsets:**
        * **OK-VQA / A-OKVQA:** knowledge-intensive VQA.
        * **ScienceQA:** scientific commonsense reasoning.
        * **ViQuAE (if available):** entity-knowledge retrieval.
    * **Expected behavior:** the model should emit `<RET>`, and the more accurate its description, the better.

* **Negative Set (retrieval forbidden or minimal): 50%**
    * **Task definition:** the answer is in the image itself; retrieval wastes time and may even inject noise.
    * **Selected subsets:**
        * **TextVQA / OCR-VQA:** reading text off the image. Retrieving similar images tends to cause misread text.
        * **VQA-v2:** basic visual recognition (color, count, position).
        * **CLEVR:** pure logical reasoning.
    * **Expected behavior:** the model should answer directly with `<ANS>`, or issue `<ANS>` immediately after one retrieval attempt turns out to be useless.

---

## Part 2: Interaction Environment (The Rollout Pipeline)

This is the **dynamic process** that happens during GRPO sampling. Each rollout is an independent exploration.

### Core logic: while loop + dynamic top-1 injection

1. **Initial state:** input `[User Image] + [Question]`.
2. **Model generation (Action):**
    * **Path A (direct answer):** the model outputs `<ANS> answer` $\rightarrow$ **done**.
    * **Path B (retrieval):** the model outputs `<RET> description text` $\rightarrow$ **pause**.
3. **Environment feedback:**
    * Extract the `description text`.
    * Use SigLIP to score its similarity against the **Candidate Pool**.
    * Take **only the top-1** most similar image and its caption/answer.
    * **Append to history:** add `<RET>...` and `System: Retrieved Image...` to the conversation history.
4. **Loop condition:**
    * Feed the new history back to the model and start the next turn.
    * **Max turns:** 3. If there is still no `<ANS>` after 3 turns, force-truncate and apply a heavy penalty.

### Candidate Pool construction strategy

**✅ Final choice: Sub-dataset Pool (one pool per task)**

* **Scope:** when training on an ok-vqa sample, the retrieval pool contains only ok-vqa images.
* **Rationale:** a global M3IT pool is too large (millions of images); retrieval becomes slow and noisy. Per-task pools are both fast and precise.
* **Self-Exclusion (mandatory):** the query image itself must be excluded at retrieval time.
    * Without it, the model only needs to learn to describe the original image; searching with SigLIP would then return the original with similarity 1.0 every time.
    * The model would learn to "copy the original" instead of "finding similar cases".
    * **Implementation:** if the top-1 image ID == the query image ID, fall back to top-2 (a sketch follows this list).
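A minimal sketch of that fallback, assuming `similarities` is a SigLIP score vector over `candidates` in the current pool (the function name and signature are illustrative, not from this repo's code):

```python
import torch

def top1_with_self_exclusion(similarities: torch.Tensor,
                             candidates: list, query_id: str):
    """Return the best candidate that is not the query image itself."""
    order = torch.argsort(similarities, descending=True)
    for idx in order.tolist():
        if candidates[idx]["id"] != query_id:  # skip the query; effectively take top-2
            return candidates[idx], similarities[idx].item()
    raise ValueError("candidate pool contains only the query image")
```

environment.py achieves the same guarantee by filtering `exclude_id` out before the argmax; the effect is identical.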
---

## Part 3: Reward Design (The Robust Reward) - **Revised After Data Analysis**

### ⚠️ Important revision: remove R_faith

**Data analysis results (500 samples):**
- Image-to-caption similarity: mean = 0.0732, max = 0.1797
- **Conclusion:** SigLIP's text encoder handles short answers very poorly; the scores are pure noise
- **R_faith is demonstrably ineffective and has been removed entirely**

### Final reward formula

$$R_{total} = R_{outcome} + \text{Gate}(IsCorrect) \times R_{rel} + R_{penalty}$$

### 1. $R_{outcome}$: The Outcome Hard Metric (The Gatekeeper)
* **Logic:** no matter how fancy the process, a wrong answer earns nothing.
* **Computation:**
    * **Answer correct (F1/Exact Match):** **+1.0**
    * **Answer wrong:** **-1.0**

### 2. $R_{penalty}$: Stepped Turn Penalty (Efficiency)
* **Logic:** retrieval has a cost. The more shots used, the steeper the (non-linear) deduction.
* **Computation:**
    * **0-shot (direct answer):** $0.0$
    * **1-shot:** $-0.1$
    * **2-shot:** $-0.25$
    * **3-shot:** $-0.5$
    * **Reaching max_turns without `<ANS>`:** $-2.0$ (infinite-loop penalty)
* **Effect:** if the model can answer correctly with 1 shot (1.0 - 0.1 + R_rel), it will never use 3 shots (1.0 - 0.5 + R_rel).

### 3. $R_{rel}$: Retrieval Effectiveness (Relevance) - **The Only Visual Process Reward**

* **Goal:** ensure the retrieved image is actually similar to the query image (evidence that the description worked).
* **Computation:** `scale_siglip_score(SigLIP_Score(query image, retrieved top-1 image))`
* **Normalization function:**
```python
def scale_siglip_score(raw_score, min_th=0.33, max_th=0.97):
    """
    Thresholds from the data analysis:
    - min_th = 0.33 (P05 of same-subdir pairs)
    - max_th = 0.97 (P95 of same-subdir pairs)
    """
    scaled = (raw_score - min_th) / (max_th - min_th)
    return max(0.0, min(1.0, scaled))
```
* **Gating:** $R_{rel}$ only takes effect when $R_{outcome} > 0$ (the answer is correct)!

### 4. Forced truncation

* **Trigger:** the model emits `<RET>` 3 times in a row without ever emitting `<ANS>`
* **Penalty:** $R_{penalty} = -2.0$
* **Logic:** a wrong answer costs -1.0; an infinite loop not only fails to answer but also burns 3x the compute, so it must be punished harder than a plain wrong answer
* **Implementation:** truncate the trajectory during training and assign the penalty directly; do not force an `<ANS>` output (the sketch below combines all three terms)
---
|
| 114 |
+
|
| 115 |
+
## 第四部分:GRPO 训练推演 (Case Study) - **基于修正后的奖励**
|
| 116 |
+
|
| 117 |
+
在这个方案下,模型是如何进化的?我们模拟一个 Batch (Group Size = 8) 的竞争情况:
|
| 118 |
+
|
| 119 |
+
* **场景:** 一道中等难度的题(需要 1 张参考图)。
|
| 120 |
+
|
| 121 |
+
### 样本对比
|
| 122 |
+
|
| 123 |
+
| 样本 | 行为 | R_outcome | R_penalty | R_rel | 总分 |
|
| 124 |
+
|------|------|-----------|-----------|-------|------|
|
| 125 |
+
| 1. 直答错 | 没检索,答错 | -1.0 | 0.0 | 0.0 | **-1.0** |
|
| 126 |
+
| 2. 1-shot完美 | 描述准→搜到好图→答对 | +1.0 | -0.1 | +1.0 | **1.9** |
|
| 127 |
+
| 3. 1-shot一般 | 搜到中等图→答对 | +1.0 | -0.1 | +0.5 | **1.4** |
|
| 128 |
+
| 4. 1-shot烂图 | 搜到烂图→蒙对 | +1.0 | -0.1 | +0.0 | **0.9** |
|
| 129 |
+
| 5. 3-shot完美 | 搜了3次→答对 | +1.0 | -0.5 | +1.0 | **1.5** |
|
| 130 |
+
| 6. 直答对 | 没检索,答对 | +1.0 | 0.0 | 0.0 | **1.0** |
|
| 131 |
+
| 7. 死循环 | 3次RET未ANS | -1.0 | -2.0 | 0.0 | **-3.0** |
|
| 132 |
+
|
| 133 |
+
### GRPO 的梯度更新方向
|
| 134 |
+
|
| 135 |
+
模型会发现 **样本 2 (1.9分)** 是绝对的赢家:
|
| 136 |
+
1. 它比样本 1 强:说明**该搜的时候必须搜**
|
| 137 |
+
2. 它比样本 5 强:说明**能搜 1 次别搜 3 次**(效率优先)
|
| 138 |
+
3. 它比样本 4 强:说明**描述必须写得准**(R_rel 高才能拿高分)
|
| 139 |
+
4. 它比样本 6 强:说明**对于需要检索的题,检索比瞎猜好**
|
| 140 |
+
|
| 141 |
+
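GRPO turns these per-trajectory scores into advantages by normalizing within the group. A minimal sketch of that standard normalization (not taken from this repo's trainer), applied to the seven totals above:

```python
import numpy as np

# Totals from the sample-comparison table (samples 1-7).
rewards = np.array([-1.0, 1.9, 1.4, 0.9, 1.5, 1.0, -3.0])
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
# Sample 2 (1.9) receives the largest positive advantage, so its trajectory
# (one precise <RET>, then <ANS>) is reinforced hardest; the infinite loop
# (sample 7) is pushed down hardest.
```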
---

## Part 5: Implementation Details

### Training configuration
* **Model:** Qwen3-VL-8B (starting from the SFT checkpoint)
* **SFT checkpoint:** `/workspace/xiaobin/SFT_model/hf_qwen3vl_siglip_vqa_iter_0000881`
* **Hardware:** 8×H100 (80GB)
* **Group Size:** 8 (8 trajectories sampled per query)
* **Framework:** Megatron-BPLM
* **Parallelism:** TP=8, DP=1

### Data distribution validation
* **SigLIP image-to-image (same subdir):**
    * Mean: 0.5188, Std: 0.1689
    * P05: 0.331, P95: 0.969
    * ✅ Good separation; suitable as a reward signal

### Execution checklist

1. **Data cleaning:** split M3IT 50:50 into Positive (OK-VQA etc.) and Negative (TextVQA etc.)
2. **Pipeline:** implement the `while` loop, injecting only SigLIP's top-1 each turn (query image excluded)
3. **Reward:**
    - remove R_faith
    - use R_rel (min_th=0.33, max_th=0.97)
    - strictly enforce the gate and the non-linear turn penalty
4. **Launch:** this is a self-play process; the model converges from "retrieving at random" to "hunting with precision"
ICL/RL/plot_metrics.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Parse GRPO training log and plot key metrics."""
|
| 3 |
+
|
| 4 |
+
import re
|
| 5 |
+
import ast
|
| 6 |
+
import os
|
| 7 |
+
import matplotlib
|
| 8 |
+
matplotlib.use('Agg')
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
LOG_FILE = "/workspace/xiaobin/RL/train_grpo_20260224_133510.log"
|
| 13 |
+
SAVE_DIR = "/workspace/xiaobin/RL/plots"
|
| 14 |
+
os.makedirs(SAVE_DIR, exist_ok=True)
|
| 15 |
+
|
| 16 |
+
# Parse all metric lines
|
| 17 |
+
records = []
|
| 18 |
+
with open(LOG_FILE, "r") as f:
|
| 19 |
+
for line in f:
|
| 20 |
+
line = line.strip()
|
| 21 |
+
if line.startswith("{'loss':") and "'epoch'" in line:
|
| 22 |
+
# skip the final summary line (has 'train_runtime')
|
| 23 |
+
if "'train_runtime'" in line:
|
| 24 |
+
continue
|
| 25 |
+
try:
|
| 26 |
+
d = ast.literal_eval(line)
|
| 27 |
+
# convert all values to float
|
| 28 |
+
d = {k: float(v) for k, v in d.items()}
|
| 29 |
+
records.append(d)
|
| 30 |
+
except Exception:
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
print(f"Parsed {len(records)} training steps")
|
| 34 |
+
|
| 35 |
+
if not records:
|
| 36 |
+
print("No records found!")
|
| 37 |
+
exit(1)
|
| 38 |
+
|
| 39 |
+
# Extract steps (1-indexed) and epochs
|
| 40 |
+
steps = list(range(1, len(records) + 1))
|
| 41 |
+
epochs = [r['epoch'] for r in records]
|
| 42 |
+
|
| 43 |
+
# Define which metrics to plot, grouped logically
|
| 44 |
+
plot_configs = [
|
| 45 |
+
# (filename, title, metrics_list, ylabel)
|
| 46 |
+
("loss.png", "Training Loss", [("loss", "Loss")], "Loss"),
|
| 47 |
+
("grad_norm.png", "Gradient Norm", [("grad_norm", "Grad Norm")], "Grad Norm"),
|
| 48 |
+
("learning_rate.png", "Learning Rate", [("learning_rate", "LR")], "Learning Rate"),
|
| 49 |
+
("reward.png", "Reward (mean & std)", [("reward", "Reward Mean"), ("reward_std", "Reward Std")], "Reward"),
|
| 50 |
+
("reward_detail.png", "Reward Function Detail",
|
| 51 |
+
[("rewards/reward_function/mean", "Mean"), ("rewards/reward_function/std", "Std")], "Reward"),
|
| 52 |
+
("kl_divergence.png", "KL Divergence", [("kl", "KL")], "KL"),
|
| 53 |
+
("entropy.png", "Entropy", [("entropy", "Entropy")], "Entropy"),
|
| 54 |
+
("completion_length.png", "Completion Length",
|
| 55 |
+
[("completions/mean_length", "Mean"), ("completions/min_length", "Min"), ("completions/max_length", "Max")],
|
| 56 |
+
"Token Length"),
|
| 57 |
+
("completion_terminated_length.png", "Terminated Completion Length",
|
| 58 |
+
[("completions/mean_terminated_length", "Mean"), ("completions/min_terminated_length", "Min"),
|
| 59 |
+
("completions/max_terminated_length", "Max")], "Token Length"),
|
| 60 |
+
("clipped_ratio.png", "Completions Clipped Ratio", [("completions/clipped_ratio", "Clipped Ratio")], "Ratio"),
|
| 61 |
+
("frac_reward_zero_std.png", "Fraction Reward Zero Std", [("frac_reward_zero_std", "Frac Zero Std")], "Fraction"),
|
| 62 |
+
("clip_ratio_high.png", "Clip Ratio (High)",
|
| 63 |
+
[("clip_ratio/high_mean", "High Mean"), ("clip_ratio/high_max", "High Max")], "Ratio"),
|
| 64 |
+
("clip_ratio_low.png", "Clip Ratio (Low)",
|
| 65 |
+
[("clip_ratio/low_mean", "Low Mean"), ("clip_ratio/low_min", "Low Min")], "Ratio"),
|
| 66 |
+
("step_time.png", "Step Time", [("step_time", "Step Time (s)")], "Seconds"),
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
# Smoothing helper (simple moving average)
|
| 70 |
+
def smooth(y, window=21):
|
| 71 |
+
if len(y) < window:
|
| 72 |
+
return y
|
| 73 |
+
cumsum = np.cumsum(np.insert(y, 0, 0))
|
| 74 |
+
return (cumsum[window:] - cumsum[:-window]) / window
|
| 75 |
+
|
| 76 |
+
# Plot each config
|
| 77 |
+
for fname, title, metrics, ylabel in plot_configs:
|
| 78 |
+
fig, ax = plt.subplots(figsize=(12, 5))
|
| 79 |
+
for key, label in metrics:
|
| 80 |
+
vals = [r.get(key, float('nan')) for r in records]
|
| 81 |
+
ax.plot(steps, vals, alpha=0.3, linewidth=0.8)
|
| 82 |
+
# add smoothed line
|
| 83 |
+
s = smooth(np.array(vals))
|
| 84 |
+
offset = (len(vals) - len(s)) // 2
|
| 85 |
+
ax.plot(steps[offset:offset+len(s)], s, linewidth=2, label=label + " (smoothed)")
|
| 86 |
+
ax.set_xlabel("Step")
|
| 87 |
+
ax.set_ylabel(ylabel)
|
| 88 |
+
ax.set_title(title)
|
| 89 |
+
ax.legend()
|
| 90 |
+
ax.grid(True, alpha=0.3)
|
| 91 |
+
plt.tight_layout()
|
| 92 |
+
path = os.path.join(SAVE_DIR, fname)
|
| 93 |
+
fig.savefig(path, dpi=150)
|
| 94 |
+
plt.close(fig)
|
| 95 |
+
print(f"Saved: {path}")
|
| 96 |
+
|
| 97 |
+
# Also make a combined overview figure
|
| 98 |
+
fig, axes = plt.subplots(3, 3, figsize=(20, 14))
|
| 99 |
+
overview_metrics = [
|
| 100 |
+
("loss", "Loss"),
|
| 101 |
+
("reward", "Reward"),
|
| 102 |
+
("kl", "KL Divergence"),
|
| 103 |
+
("entropy", "Entropy"),
|
| 104 |
+
("grad_norm", "Grad Norm"),
|
| 105 |
+
("completions/mean_length", "Completion Mean Length"),
|
| 106 |
+
("frac_reward_zero_std", "Frac Reward Zero Std"),
|
| 107 |
+
("completions/clipped_ratio", "Clipped Ratio"),
|
| 108 |
+
("step_time", "Step Time (s)"),
|
| 109 |
+
]
|
| 110 |
+
for ax, (key, title) in zip(axes.flat, overview_metrics):
|
| 111 |
+
vals = [r.get(key, float('nan')) for r in records]
|
| 112 |
+
ax.plot(steps, vals, alpha=0.3, linewidth=0.8, color='steelblue')
|
| 113 |
+
s = smooth(np.array(vals))
|
| 114 |
+
offset = (len(vals) - len(s)) // 2
|
| 115 |
+
ax.plot(steps[offset:offset+len(s)], s, linewidth=2, color='orangered')
|
| 116 |
+
ax.set_title(title, fontsize=12)
|
| 117 |
+
ax.set_xlabel("Step", fontsize=9)
|
| 118 |
+
ax.grid(True, alpha=0.3)
|
| 119 |
+
|
| 120 |
+
fig.suptitle("GRPO Training Overview", fontsize=16, fontweight='bold')
|
| 121 |
+
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
| 122 |
+
path = os.path.join(SAVE_DIR, "overview.png")
|
| 123 |
+
fig.savefig(path, dpi=150)
|
| 124 |
+
plt.close(fig)
|
| 125 |
+
print(f"Saved: {path}")
|
| 126 |
+
|
| 127 |
+
print(f"\nAll plots saved to: {SAVE_DIR}")
|
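A quick sanity check of the moving-average helper and the centering offset used above (a standalone sketch, not part of plot_metrics.py): for a linear ramp, the offset lines the smoothed curve up exactly with the raw one.

```python
# Standalone sketch: sanity-check smooth() and the plotting offset.
import numpy as np

def smooth(y, window=21):
    if len(y) < window:
        return y
    cumsum = np.cumsum(np.insert(y, 0, 0))
    return (cumsum[window:] - cumsum[:-window]) / window

y = np.arange(100, dtype=float)    # a linear ramp
s = smooth(y)                      # length 100 - 21 + 1 = 80
offset = (len(y) - len(s)) // 2    # 10: shift applied when plotting
print(y[offset], s[0])             # 10.0 10.0 -> curves align exactly
```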
ICL/RL/plots/clip_ratio_high.png
ADDED
ICL/RL/plots/clip_ratio_low.png
ADDED
ICL/RL/plots/clipped_ratio.png
ADDED
ICL/RL/plots/completion_length.png
ADDED
ICL/RL/plots/entropy.png
ADDED
ICL/RL/plots/frac_reward_zero_std.png
ADDED
ICL/RL/plots/grad_norm.png
ADDED
ICL/RL/plots/learning_rate.png
ADDED
ICL/RL/quickstart.sh
ADDED
@@ -0,0 +1,36 @@
#!/bin/bash
# Quick start script for GRPO training

echo "=================================="
echo "GRPO Training Quick Start"
echo "=================================="
echo ""

# Check if virtual environment exists
if [ ! -d "venv" ]; then
    echo "Creating virtual environment..."
    python3 -m venv venv
fi

# Activate virtual environment
echo "Activating virtual environment..."
source venv/bin/activate

# Install dependencies
echo "Installing dependencies..."
pip install -r requirements.txt

echo ""
echo "=================================="
echo "Setup complete!"
echo "=================================="
echo ""
echo "Next steps:"
echo "  1. Test reward functions: python test_rewards.py"
echo "  2. Visualize rewards: python visualize_rewards.py"
echo "  3. Start training: python train_grpo.py"
echo ""
echo "To monitor training:"
echo "  - Check wandb dashboard (if enabled)"
echo "  - View logs in output/ directory"
echo ""
ICL/RL/requirements.txt
ADDED
@@ -0,0 +1,16 @@
torch>=2.0.0
transformers>=4.40.0
trl>=0.12.0
datasets>=2.14.0
accelerate>=0.27.0
peft>=0.10.0
bitsandbytes>=0.43.0
sentencepiece>=0.2.0
protobuf>=3.20.0
pillow>=10.0.0
numpy>=1.24.0
scikit-learn>=1.3.0
tqdm>=4.65.0
wandb>=0.15.0
matplotlib>=3.7.0
pyyaml>=6.0
ICL/RL/reward_functions.py
ADDED
@@ -0,0 +1,135 @@
"""
Reward functions for GRPO training.
Based on the plan: R_total = R_outcome + Gate(IsCorrect) × R_rel + R_penalty.
Note: the hard IsCorrect gate has since been softened; R_rel is now soft-gated
by F1, so in practice R_total = R_outcome + R_rel + R_penalty.
"""
import re
from typing import List, Dict, Any
import torch


def scale_siglip_score(raw_score: float, min_th: float = 0.33, max_th: float = 0.97) -> float:
    """
    Scale SigLIP similarity score to [0, 1] range.
    Based on data analysis: P05=0.33, P95=0.97 for same-subdir pairs.
    """
    scaled = (raw_score - min_th) / (max_th - min_th)
    return max(0.0, min(1.0, scaled))


def compute_f1_score(prediction: str, ground_truth: str) -> float:
    """Compute F1 score between prediction and ground truth."""
    pred_tokens = set(prediction.lower().split())
    gt_tokens = set(ground_truth.lower().split())

    if len(pred_tokens) == 0 or len(gt_tokens) == 0:
        return 0.0

    common = pred_tokens & gt_tokens
    if len(common) == 0:
        return 0.0

    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(gt_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    return f1


def extract_answer(text: str) -> str:
    """Extract answer from model output after <ANS> token."""
    if "<ANS>" in text:
        return text.split("<ANS>")[-1].strip()
    return text.strip()


def count_retrieval_steps(trajectory: str) -> int:
    """Count number of <RET> tokens in trajectory."""
    return trajectory.count("<RET>")


def compute_reward(
    trajectory: str,
    ground_truth: str,
    siglip_scores: List[float],
    max_turns: int = 3,
    timeout: bool = False
) -> Dict[str, float]:
    """
    Compute total reward for a trajectory.

    Args:
        trajectory: Full model output with <RET> and <ANS> tokens
        ground_truth: Ground truth answer
        siglip_scores: List of SigLIP similarity scores for each retrieval
        max_turns: Maximum allowed retrieval turns
        timeout: Whether trajectory hit max turns without <ANS>

    Returns:
        Dictionary with reward components and total
    """
    # Extract final answer
    prediction = extract_answer(trajectory)

    # Count retrieval steps
    num_steps = count_retrieval_steps(trajectory)

    # R_outcome: continuous outcome reward (dense signal)
    # Map f1 ∈ [0, 1] to r_outcome ∈ [-1, 1]
    f1 = compute_f1_score(prediction, ground_truth)
    r_outcome = 2.0 * f1 - 1.0

    # Keep a softer correctness flag for logging only (no longer gates reward)
    is_correct = f1 > 0.3

    # R_penalty: Step-based penalty (non-linear)
    if timeout:
        # Death penalty for infinite loop (softened to encourage exploration early)
        r_penalty = -1.5
    else:
        # Lighter penalty encourages exploration/retrieval early
        penalty_map = {0: 0.0, 1: -0.05, 2: -0.15, 3: -0.3}
        r_penalty = penalty_map.get(num_steps, -0.3)

    # R_rel: Retrieval relevance (decoupled; always provides signal if retrieval happened)
    r_rel = 0.0
    if len(siglip_scores) > 0:
        # Average scaled SigLIP scores across all retrievals
        scaled_scores = [scale_siglip_score(score) for score in siglip_scores]
        avg_rel = sum(scaled_scores) / len(scaled_scores)
        # Soft-gate by f1 but keep a minimum signal; downweight to avoid dominating r_outcome
        r_rel = avg_rel * max(0.3, f1) * 0.5

    # Total reward
    r_total = r_outcome + r_rel + r_penalty

    return {
        "r_outcome": r_outcome,
        "r_penalty": r_penalty,
        "r_rel": r_rel,
        "r_total": r_total,
        "f1_score": f1,
        "num_steps": num_steps,
        "is_correct": is_correct,
        "timeout": timeout
    }


def batch_compute_rewards(
    trajectories: List[str],
    ground_truths: List[str],
    siglip_scores_list: List[List[float]],
    max_turns: int = 3,
    timeouts: List[bool] = None
) -> List[float]:
    """
    Compute rewards for a batch of trajectories.
    Returns list of total rewards for GRPO.
    """
    if timeouts is None:
        timeouts = [False] * len(trajectories)

    rewards = []
    for traj, gt, scores, timeout in zip(trajectories, ground_truths, siglip_scores_list, timeouts):
        reward_dict = compute_reward(traj, gt, scores, max_turns, timeout)
        rewards.append(reward_dict["r_total"])

    return rewards
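As a quick illustration of how the components combine (the trajectory and the 0.8 SigLIP score below are made up): a one-retrieval trajectory with an exactly correct answer earns the full outcome reward, a small step penalty, and a relevance bonus.

```python
# Illustrative call only: trajectory and SigLIP score are invented.
from reward_functions import compute_reward

r = compute_reward(
    trajectory="<RET> a red fruit on a table <ANS> apple",
    ground_truth="apple",
    siglip_scores=[0.8],
)
# f1 = 1.0 -> r_outcome = 1.0; one <RET> -> r_penalty = -0.05;
# scaled SigLIP = (0.8 - 0.33) / 0.64 ≈ 0.73 -> r_rel ≈ 0.73 * 1.0 * 0.5 ≈ 0.37
print(r["r_total"])  # ≈ 1.32
```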
ICL/RL/run_grpo.sh
ADDED
@@ -0,0 +1,29 @@
#!/bin/bash
cd /workspace/xiaobin/RL

LOG_FILE="train_grpo_$(date +%Y%m%d_%H%M%S).log"
KEY_LOG="key_metrics_$(date +%Y%m%d_%H%M%S).log"

echo "Full log: $LOG_FILE"
echo "Key metrics log: $KEY_LOG"

# Run training, full output to LOG_FILE
nohup python train_grpo.py > "$LOG_FILE" 2>&1 &
PID=$!
echo "Training started with PID: $PID"
echo "$PID" > train_pid.txt

# Background process to extract key metrics every 30s
nohup bash -c "
while kill -0 $PID 2>/dev/null; do
    grep -E '(loss|reward|entropy|clip_ratio|grad_norm|epoch|frac_reward)' \"$LOG_FILE\" | tail -200 > \"$KEY_LOG\"
    sleep 30
done
# Final extraction
grep -E '(loss|reward|entropy|clip_ratio|grad_norm|epoch|frac_reward)' \"$LOG_FILE\" > \"$KEY_LOG\"
echo 'Training finished.' >> \"$KEY_LOG\"
" > /dev/null 2>&1 &

echo "Metric watcher started."
echo "To monitor: tail -f /workspace/xiaobin/RL/$LOG_FILE"
echo "Key metrics: cat /workspace/xiaobin/RL/$KEY_LOG"
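Because the PID is recorded in train_pid.txt, a run launched this way can be stopped later; the metric watcher then exits on its own once `kill -0 $PID` fails. A minimal sketch:

```python
# Minimal sketch: stop a run launched by run_grpo.sh using the recorded PID.
import os
import signal

with open("/workspace/xiaobin/RL/train_pid.txt") as f:
    os.kill(int(f.read().strip()), signal.SIGTERM)
```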
ICL/RL/siglip_analysis/score_distribution.png
ADDED
ICL/RL/test_device_config.py
ADDED
@@ -0,0 +1,111 @@
"""
Test script to verify multi-GPU device configuration is correct.
"""
import torch
from transformers import AutoModel, AutoProcessor

def test_device_config():
    print("="*80)
    print("TESTING DEVICE CONFIGURATION")
    print("="*80)

    # Check CUDA availability
    print(f"\nCUDA available: {torch.cuda.is_available()}")
    print(f"GPU count: {torch.cuda.device_count()}")

    if torch.cuda.device_count() > 0:
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    # Load SigLIP with device_map="auto"
    print("\n" + "="*80)
    print("LOADING SIGLIP MODEL")
    print("="*80)

    siglip_model_name = "/workspace/siglip2-so400m-patch16-naflex"
    siglip_model = AutoModel.from_pretrained(siglip_model_name, device_map="auto")
    siglip_processor = AutoProcessor.from_pretrained(siglip_model_name)

    # Check where model parameters are
    print("\nSigLIP model parameter devices:")
    for name, param in list(siglip_model.named_parameters())[:5]:
        print(f"  {name}: {param.device}")

    # Get the device of the first parameter
    model_device = next(siglip_model.parameters()).device
    print(f"\nFirst parameter device: {model_device}")

    # Test text encoding
    print("\n" + "="*80)
    print("TESTING TEXT ENCODING")
    print("="*80)

    text_inputs = siglip_processor(
        text=["a photo of a cat"],
        return_tensors="pt",
        padding=True
    )

    print(f"Text inputs device before moving: {text_inputs['input_ids'].device}")

    # Move to model device
    text_inputs = {k: v.to(model_device) for k, v in text_inputs.items()}
    print(f"Text inputs device after moving: {text_inputs['input_ids'].device}")

    with torch.no_grad():
        text_embeds = siglip_model.get_text_features(**text_inputs)
    print(f"Text embeddings device: {text_embeds.device}")
    print(f"Text embeddings shape: {text_embeds.shape}")

    # Test image encoding
    print("\n" + "="*80)
    print("TESTING IMAGE ENCODING")
    print("="*80)

    from PIL import Image
    import numpy as np

    # Create a dummy image
    dummy_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))

    image_inputs = siglip_processor(
        images=[dummy_image],
        return_tensors="pt",
        padding=True
    )

    print(f"Image inputs device before moving: {image_inputs['pixel_values'].device}")

    # Move to model device
    image_inputs = {k: v.to(model_device) for k, v in image_inputs.items()}
    print(f"Image inputs device after moving: {image_inputs['pixel_values'].device}")

    with torch.no_grad():
        image_embeds = siglip_model.get_image_features(**image_inputs)
    print(f"Image embeddings device: {image_embeds.device}")
    print(f"Image embeddings shape: {image_embeds.shape}")

    # Test similarity computation
    print("\n" + "="*80)
    print("TESTING SIMILARITY COMPUTATION")
    print("="*80)

    # Create dummy embeddings on CPU
    dummy_pool_embeds = torch.randn(100, text_embeds.shape[-1])
    print(f"Pool embeddings device (CPU): {dummy_pool_embeds.device}")

    # Move to same device as text_embeds
    dummy_pool_embeds = dummy_pool_embeds.to(text_embeds.device)
    print(f"Pool embeddings device after moving: {dummy_pool_embeds.device}")

    # Compute similarity
    similarities = (text_embeds @ dummy_pool_embeds.T).squeeze(0)
    print(f"Similarities device: {similarities.device}")
    print(f"Similarities shape: {similarities.shape}")

    print("\n" + "="*80)
    print("ALL TESTS PASSED!")
    print("="*80)

if __name__ == "__main__":
    test_device_config()
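When `device_map="auto"` shards a model across GPUs, the first-parameter device used above is only a convention; transformers/accelerate also record the full module-to-device mapping on the model. A short sketch (assuming the same local SigLIP path):

```python
# Sketch: inspect the module-to-device placement recorded at load time.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "/workspace/siglip2-so400m-patch16-naflex", device_map="auto"
)
# hf_device_map is attached when device_map is used; fall back gracefully otherwise.
print(getattr(model, "hf_device_map", "no sharding map recorded"))
```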
ICL/RL/train_grpo.py
ADDED
@@ -0,0 +1,389 @@
"""
Main GRPO training script for retrieval-augmented VQA.
Based on TRL's GRPOTrainer with custom reward functions and environment.
"""
import os
os.environ["PYTHONIOENCODING"] = "utf-8"
os.environ["LC_ALL"] = "C.UTF-8"
import torch
from transformers import AutoTokenizer, AutoProcessor
from trl import GRPOTrainer, GRPOConfig
from datasets import Dataset
import wandb
from typing import List, Dict, Any

from reward_functions import batch_compute_rewards
from environment import RetrievalEnvironment
from data_utils import (
    load_m3it_splits,
    create_balanced_dataset,
    build_candidate_pools,
    format_prompt
)


def create_reward_function(environment: RetrievalEnvironment):
    """
    Create a reward function closure that captures the environment.
    This function will be passed to GRPOTrainer as reward_funcs.
    """
    def parse_retrieval_queries(completion: str) -> List[str]:
        """
        Parse retrieval queries from completion text.
        Extracts text between <RET> and the next <ANS> or <RET>.
        """
        queries = []
        parts = completion.split("<RET>")

        for i in range(1, len(parts)):  # Skip first part (before first <RET>)
            query_part = parts[i]
            # Extract until next special token or end
            if "<ANS>" in query_part:
                query = query_part.split("<ANS>")[0].strip()
            elif "<RET>" in query_part:
                query = query_part.split("<RET>")[0].strip()
            else:
                query = query_part.strip()

            if query:
                queries.append(query)

        return queries

    def reward_function(completions: List[str], **kwargs) -> List[float]:
        """
        Custom reward function for GRPO.

        For each completion:
        1. Parse <RET> queries from the trajectory
        2. Execute retrieval to get SigLIP scores
        3. Compute reward using the full reward function
        """
        # Extract ground truth answers from kwargs
        # GRPO passes dataset columns through kwargs
        answers = kwargs.get("answer", [])
        images = kwargs.get("image", [])
        datasets = kwargs.get("dataset", [])
        ids = kwargs.get("id", [])

        if not answers:
            # Fallback: return zero rewards if no ground truth
            print("WARNING: No ground truth answers found in kwargs")
            return [0.0] * len(completions)

        rewards = []
        siglip_scores_list = []
        timeouts = []

        for completion, answer, image, dataset_name, sample_id in zip(
            completions, answers, images, datasets, ids
        ):
            # Convert completion to string
            # For conversational format, completion is a list like [{"role": "assistant", "content": "..."}]
            if isinstance(completion, list) and len(completion) > 0:
                if isinstance(completion[0], dict) and "content" in completion[0]:
                    completion_text = completion[0]["content"]
                else:
                    completion_text = str(completion)
            else:
                completion_text = str(completion)

            # Parse retrieval queries from completion
            retrieval_queries = parse_retrieval_queries(completion_text)

            # Execute retrieval for each query to get SigLIP scores
            siglip_scores = []
            if len(retrieval_queries) > 0:
                for query_text in retrieval_queries:
                    try:
                        # Retrieve top-1 similar image using text query
                        retrieved_candidate, _ = environment.retrieve_top1(
                            query_text=query_text,
                            dataset_name=dataset_name,
                            exclude_id=sample_id
                        )

                        # Compute image-to-image similarity for reward
                        img_similarity = environment.compute_image_similarity(
                            query_image=image,
                            retrieved_image=retrieved_candidate["image"]
                        )
                        siglip_scores.append(img_similarity)
                    except Exception as e:
                        print(f"Retrieval error for query '{query_text}': {e}")
                        siglip_scores.append(0.0)

            # Check for timeout (exceeded max turns without <ANS>)
            num_rets = completion_text.count("<RET>")
            has_answer = "<ANS>" in completion_text
            timeout = (num_rets >= environment.max_turns and not has_answer)

            siglip_scores_list.append(siglip_scores)
            timeouts.append(timeout)

        # Compute rewards using our custom function
        # Convert completions to strings for batch_compute_rewards
        completion_texts = []
        for comp in completions:
            if isinstance(comp, list) and len(comp) > 0:
                if isinstance(comp[0], dict) and "content" in comp[0]:
                    completion_texts.append(comp[0]["content"])
                else:
                    completion_texts.append(str(comp))
            else:
                completion_texts.append(str(comp))

        rewards = batch_compute_rewards(
            trajectories=completion_texts,
            ground_truths=answers,
            siglip_scores_list=siglip_scores_list,
            max_turns=environment.max_turns,
            timeouts=timeouts
        )

        return rewards

    return reward_function


def load_models(
    model_path: str,
    siglip_model_name: str = "/workspace/siglip2-so400m-patch16-naflex",
    device: str = "cuda"
):
    """Load VLM and SigLIP models."""
    print(f"Loading VLM from {model_path}...")

    # Import the specific Qwen3VL generation class
    from transformers.models.qwen3_vl import Qwen3VLForConditionalGeneration
    from transformers import AutoModel

    model = Qwen3VLForConditionalGeneration.from_pretrained(
        model_path,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="flash_attention_2"
    )

    # Use processor instead of tokenizer for vision-language models
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    print(f"Loading SigLIP from {siglip_model_name}...")
    siglip_model = AutoModel.from_pretrained(siglip_model_name, device_map="auto")
    siglip_processor = AutoProcessor.from_pretrained(siglip_model_name)

    return model, processor, siglip_model, siglip_processor


def prepare_dataset_for_grpo(samples: List[Dict]) -> Dataset:
    """
    Convert samples to HuggingFace Dataset format for GRPO.
    GRPO expects simple chat messages (strings), and will handle image insertion automatically.
    """
    dataset_dict = {
        "prompt": [],
        "image": [],
        "answer": [],
        "dataset": [],
        "id": [],
        "needs_retrieval": []
    }

    for sample in samples:
        # Create simple chat messages - GRPO will handle image insertion
        messages = [
            {
                "role": "system",
                "content": (
                    "You are a visual question answering assistant. "
                    "You can either answer directly using <ANS> or retrieve similar images using <RET>.\n"
                    "- Use <RET> followed by a description when you need to see similar examples.\n"
                    "- Use <ANS> followed by your answer when you're ready to answer."
                )
            },
            {
                "role": "user",
                "content": f"Question: {sample['question']}"
            }
        ]

        dataset_dict["prompt"].append(messages)
        dataset_dict["image"].append(sample["image"])
        dataset_dict["answer"].append(sample["answer"])
        dataset_dict["dataset"].append(sample["dataset"])
        dataset_dict["id"].append(sample["id"])
        dataset_dict["needs_retrieval"].append(sample["needs_retrieval"])

    # Create dataset without using from_dict to avoid PyArrow issues
    from datasets import Dataset as HFDataset

    # Convert to list of dicts format
    data_list = []
    for i in range(len(dataset_dict["prompt"])):
        data_list.append({
            "prompt": dataset_dict["prompt"][i],
            "image": dataset_dict["image"][i],
            "answer": dataset_dict["answer"][i],
            "dataset": dataset_dict["dataset"][i],
            "id": dataset_dict["id"][i],
            "needs_retrieval": dataset_dict["needs_retrieval"][i]
        })

    return HFDataset.from_list(data_list)


def main():
    # Configuration
    MODEL_PATH = "/workspace/xiaobin/SFT_model/hf_qwen3vl_siglip_vqa_iter_0000881"
    OUTPUT_DIR = "/workspace/xiaobin/RL/output"
    SIGLIP_MODEL = "/workspace/siglip2-so400m-patch16-naflex"
    RL_DATASET_PATH = "/workspace/xiaobin/RL_data/rl_train.jsonl"  # Pre-built dataset

    # Training hyperparameters - Optimized for speed with more GPU memory
    LEARNING_RATE = 5e-6  # Reduced from 1e-5 to slow down collapse
    BATCH_SIZE = 64  # Increased from 32
    NUM_GENERATIONS = 4  # 64/4=16 unique prompts per step
    MAX_TURNS = 3
    NUM_EPOCHS = 3
    MAX_PROMPT_LENGTH = 512
    MAX_COMPLETION_LENGTH = 256

    # Data parameters
    MAX_SAMPLES = None  # Use all 10000 samples
    USE_WANDB = False  # Disabled wandb

    # Initialize wandb
    if USE_WANDB:
        wandb.init(
            project="grpo-retrieval-vqa",
            config={
                "model": MODEL_PATH,
                "learning_rate": LEARNING_RATE,
                "batch_size": BATCH_SIZE,
                "num_generations": NUM_GENERATIONS,
                "max_turns": MAX_TURNS,
            }
        )

    # Load models
    model, processor, siglip_model, siglip_processor = load_models(
        MODEL_PATH,
        SIGLIP_MODEL
    )

    # Ensure model has generation_config
    from transformers import GenerationConfig
    if not hasattr(model, 'generation_config') or model.generation_config is None:
        model.generation_config = GenerationConfig.from_model_config(model.config)
        print("Created default generation_config for model")

    # Load and prepare data
    print("\n" + "="*80)
    print("LOADING DATA")
    print("="*80)

    import json
    all_samples = []
    with open(RL_DATASET_PATH, 'r', encoding='utf-8') as f:
        for line in f:
            sample = json.loads(line)
            # Convert to expected format
            all_samples.append({
                'id': sample['id'],
                'image': sample['image'],
                'question': sample['question'],
                'answer': sample['answer'],
                'dataset': sample['subdir'],
                'needs_retrieval': sample['category'] == 'positive'
            })

    # Limit samples if specified
    if MAX_SAMPLES is not None and len(all_samples) > MAX_SAMPLES:
        import random
        random.seed(42)
        all_samples = random.sample(all_samples, MAX_SAMPLES)

    print(f"Loaded {len(all_samples)} samples from {RL_DATASET_PATH}")

    # Build candidate pools for retrieval
    positive_samples = [s for s in all_samples if s['needs_retrieval']]
    negative_samples = [s for s in all_samples if not s['needs_retrieval']]
    candidate_pools = build_candidate_pools(positive_samples, negative_samples)

    # Convert to HF Dataset
    train_dataset = prepare_dataset_for_grpo(all_samples)

    # Initialize retrieval environment
    print("\n" + "="*80)
    print("INITIALIZING ENVIRONMENT")
    print("="*80)
    environment = RetrievalEnvironment(
        siglip_model=siglip_model,
        siglip_processor=siglip_processor,
        candidate_pools=candidate_pools,
        max_turns=MAX_TURNS,
        device="cuda"
    )

    # Configure GRPO training
    print("\n" + "="*80)
    print("CONFIGURING GRPO TRAINER")
    print("="*80)
    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        num_generations=NUM_GENERATIONS,
        generation_batch_size=NUM_GENERATIONS * 16,  # 4*16=64
        max_completion_length=MAX_COMPLETION_LENGTH,
        logging_steps=10,
        save_steps=100,
        save_total_limit=3,
        bf16=True,
        gradient_accumulation_steps=1,
        warmup_steps=50,
        report_to="wandb" if USE_WANDB else "none",
        gradient_checkpointing=True,
        temperature=1.4,  # Higher temp for exploration
        top_p=0.95,  # Wider sampling
        max_grad_norm=1.0,  # Explicit gradient clipping
        lr_scheduler_type="constant_with_warmup",  # Prevent lr decay stall
        repetition_penalty=1.1,  # Discourage identical generations
        beta=0.04,  # KL regularization to prevent policy collapse
        epsilon=0.3,  # Wider clip range (default 0.2)
    )

    # Create custom reward function
    reward_func = create_reward_function(environment)

    # Initialize trainer with custom reward function
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        processing_class=processor,
        reward_funcs=reward_func,
    )

    # Start training
    print("\n" + "="*80)
    print("STARTING TRAINING")
    print("="*80)
    trainer.train()

    # Save final model
    print("\n" + "="*80)
    print("SAVING MODEL")
    print("="*80)
    final_output_dir = os.path.join(OUTPUT_DIR, "final_model")
    trainer.save_model(final_output_dir)
    print(f"Model saved to {final_output_dir}")

    if USE_WANDB:
        wandb.finish()


if __name__ == "__main__":
    main()
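The closure above relies on GRPOTrainer's convention that extra dataset columns (`answer`, `image`, `dataset`, `id`) are forwarded to the reward function as keyword lists aligned with `completions`. A toy reward function makes that contract easy to check offline (the function and data here are illustrative only):

```python
# Toy sketch of the reward_funcs contract: extra dataset columns arrive as
# keyword lists parallel to `completions`.
from typing import List

def toy_reward(completions: List[str], **kwargs) -> List[float]:
    answers = kwargs["answer"]
    return [
        1.0 if ans.lower() in str(comp).lower() else 0.0
        for comp, ans in zip(completions, answers)
    ]

print(toy_reward(
    ["<ANS> a cat", "<ANS> a dog"],
    answer=["cat", "bird"],
))  # [1.0, 0.0]
```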
ICL/RL/train_grpo_20260224_133510.log
ADDED
The diff for this file is too large to render. See raw diff
ICL/RL/train_pid.txt
ADDED
@@ -0,0 +1 @@
2895
ICL/RL/trl_source/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,17 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.13.3
    hooks:
      - id: ruff-check
        types_or: [ python, pyi ]
        args: [ --fix ]
      - id: ruff-format
        types_or: [ python, pyi ]

#  - repo: https://github.com/codespell-project/codespell
#    rev: v2.1.0
#    hooks:
#      - id: codespell
#        args:
#          - --ignore-words-list=nd,reacher,thist,ths,magent,ba
#          - --skip=docs/css/termynal.css,docs/js/termynal.js
ICL/RL/trl_source/CITATION.cff
ADDED
@@ -0,0 +1,41 @@
cff-version: 1.2.0
title: 'TRL: Transformers Reinforcement Learning'
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Leandro
    family-names: von Werra
  - given-names: Younes
    family-names: Belkada
  - given-names: Lewis
    family-names: Tunstall
  - given-names: Edward
    family-names: Beeching
  - given-names: Tristan
    family-names: Thrush
  - given-names: Nathan
    family-names: Lambert
  - given-names: Shengyi
    family-names: Huang
  - given-names: Kashif
    family-names: Rasul
  - given-names: Quentin
    family-names: Gallouédec
repository-code: 'https://github.com/huggingface/trl'
abstract: >-
  TRL (Transformers Reinforcement Learning) is an
  open-source toolkit for aligning transformer models via
  post-training. It provides practical, scalable
  implementations of SFT, reward modeling, DPO, and GRPO
  within the Hugging Face ecosystem.
keywords:
  - transformers
  - reinforcement learning
  - preference optimization
  - language model alignment
  - post-training
license: Apache-2.0
version: '0.28'
date-released: '2020-03-27'
ICL/RL/trl_source/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,133 @@

# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
  community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or advances of
  any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
  without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
feedback@huggingface.co.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series of
actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within the
community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].

Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].

For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].

[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations
ICL/RL/trl_source/CONTRIBUTING.md
ADDED
@@ -0,0 +1,411 @@
# How to contribute to TRL?

Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.

It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.

However you choose to contribute, please be mindful and respect our [code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).

**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**

## Ways to contribute

There are several ways you can contribute to TRL:

* Fix outstanding issues with the existing code.
* Submit issues related to bugs or desired new features.
* Implement trainers for new post-training algorithms.
* Contribute to the examples or the documentation.

If you don't know where to start, there is a special [Good First Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.

For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/%F0%9F%A7%92%20good%20second%20issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀

> All contributions are equally valuable to the community. 🥰

Before you start contributing make sure you have installed all the dev tools:

```bash
pip install -e .[dev]
```

## Fixing outstanding issues

If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request!

## Submitting a bug-related issue or feature request

Do your best to follow these guidelines when submitting a bug-related issue or a feature request. It will make it easier for us to come back to you quickly and with good feedback.

### Did you find a bug?

The TRL library is robust and reliable thanks to users who report the problems they encounter.

Before you report an issue, we would really appreciate it if you could **make sure the bug was not already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.

Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:

* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions.
* A short, self-contained, code snippet that allows us to reproduce the bug in less than 30s.
* The *full* traceback if an exception is raised.
* Attach any other additional information, like screenshots, you think may help.

To get the OS and software versions automatically, run the following command:

```bash
trl env
```

### Do you want a new feature?

If there is a new feature you'd like to see in TRL, please open an issue and describe:

1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? Is it something you worked on and think it could benefit the community?

   Whatever it is, we'd love to hear about it!

2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you.
3. Provide a *code snippet* that demonstrates the feature's usage.
4. If the feature is related to a paper, please include a link.

If your issue is well written we're already 80% of the way there by the time you create it.

## Do you want to implement a new trainer?

New post-training methods are published frequently and those that satisfy the following criteria are good candidates to be integrated into TRL:

* **Simplicity:** Does the new method achieve similar performance as prior methods, but with less complexity? A good example is Direct Preference Optimization (DPO) [[Rafailov et al, 2023]](https://huggingface.co/papers/2305.18290), which provided a simpler and compelling alternative to RLHF methods.
* **Efficiency:** Does the new method provide a significant improvement in training efficiency? A good example is Odds Ratio Preference Optimization (ORPO) [[Hong et al, 2023]](https://huggingface.co/papers/2403.07691), which utilizes a similar objective as DPO but requires half the GPU VRAM.

Methods that only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL.

If you want to implement a trainer for a new post-training method, first open an issue and provide the following information:

* A short description of the method and a link to the paper.
* Link to the implementation if it is open-sourced.
* Link to model weights trained with the method if they are available.

Based on the community and maintainer feedback, the next step will be to implement the trainer and config classes. See the following examples for inspiration:

* Paired preference optimisation: [`dpo_trainer.py`](./trl/trainer/dpo_trainer.py) and [`dpo_config.py`](./trl/trainer/dpo_config.py)
* RL-based optimisation: [`rloo_trainer.py`](./trl/trainer/rloo_trainer.py) and [`rloo_config.py`](./trl/trainer/rloo_config.py)
* Online optimisation: [`online_dpo_trainer.py`](./trl/trainer/online_dpo_trainer.py) and [`online_dpo_config.py`](./trl/trainer/online_dpo_config.py)

## Do you want to add documentation?

We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know how the documentation can be improved, such as typos, dead links, and any missing, unclear, or inaccurate content... We'll be happy to make the changes or help you contribute if you're interested!

## Submitting a pull request (PR)

Before writing code, we strongly advise you to search through the existing PRs or issues to make sure that nobody is already working on the same thing. If you are unsure, it is always a good idea to open an issue to get some feedback.

You will need basic `git` proficiency to be able to contribute to TRL. `git` is not the easiest tool to use but it has the greatest manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro Git](https://git-scm.com/book/en/v2) is a very good reference.

Follow these steps to start contributing:

1. Fork the [repository](https://github.com/huggingface/trl) by clicking on the 'Fork' button on the repository's page. This creates a copy of the code under your GitHub user account.

2. Clone your fork to your local disk, and add the base repository as a remote. The following command assumes you have your public SSH key uploaded to GitHub. See the following guide for more [information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).

   ```bash
   git clone git@github.com:<your Github handle>/trl.git
   cd trl
   git remote add upstream https://github.com/huggingface/trl.git
   ```

3. Create a new branch to hold your development changes, and do this for every new PR you work on.

   Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):

   ```bash
   git checkout main
   git fetch upstream
   git merge upstream/main
   ```

   Once your `main` branch is synchronized, create a new branch from it:

   ```bash
   git checkout -b a-descriptive-name-for-my-changes
   ```

   **Do not** work on the `main` branch.

4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:

   ```bash
   pip install -e .[dev]
   ```

   (If TRL was already installed in the virtual environment, remove it with `pip uninstall trl` before reinstalling it.)

   Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using the provided Dev Container. Check [the documentation on how to get started with dev containers](https://code.visualstudio.com/docs/remote/containers).

5. Develop the features on your branch.

   As you work on the features, you should make sure that the test suite passes. You should run the tests impacted by your changes like this (see below an explanation regarding the environment variable):

   ```bash
   pytest tests/<TEST_TO_RUN>.py
   ```

   > The following commands leverage the `make` utility.

   You can also run the full suite with the following command.

   ```bash
   make test
   ```

   TRL relies on `ruff` for maintaining consistent code formatting across its source files. Before submitting any PR, you should apply automatic style corrections and run code verification checks.

   We provide a `precommit` target in the `Makefile` that simplifies this process by running all required checks and optimizations on only the files modified by your PR.

   To apply these checks and corrections in one step, use:

   ```bash
   make precommit
   ```

   This command runs the following:

   * Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
   * Runs additional scripts such as adding copyright information.

   If you prefer to apply the style corrections separately or review them individually, the `pre-commit` hook will handle the formatting for the files in question.

   Once you're happy with your changes, add changed files using `git add` and make a commit with `git commit` to record your changes locally:

   ```bash
   git add modified_file.py
   git commit
   ```

   Please write [good commit messages](https://chris.beams.io/posts/git-commit/).

   It is a good idea to sync your copy of the code with the original
   repository regularly. This way you can quickly account for changes:

   ```bash
   git fetch upstream
   git rebase upstream/main
   ```

   Push the changes to your account using:

   ```bash
   git push -u origin a-descriptive-name-for-my-changes
   ```

6. Once you are satisfied (**and the checklist below is happy too**), go to the webpage of your fork on GitHub. Click on 'Pull request' to send your changes to the project maintainers for review.

7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request.

### Checklist

1. The title of your pull request should be a summary of its contribution;
2. If your pull request addresses an issue, please mention the issue number in the pull request description to make sure they are linked (and people consulting the issue know you are working on it);
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate it from PRs ready to be merged;
4. Make sure existing tests pass;
5. Add high-coverage tests. No quality testing = no merge.

### Tests

An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
the [tests folder](https://github.com/huggingface/trl/tree/main/tests).
|
| 216 |
+
|
| 217 |
+
We use `pytest` to run the tests. From the root of the
|
| 218 |
+
repository here's how to run tests with `pytest` for the library:
|
| 219 |
+
|
| 220 |
+
```bash
|
| 221 |
+
python -m pytest -sv ./tests
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
That's how `make test` is implemented (without the `pip install` line)!
|
| 225 |
+
|
| 226 |
+
You can specify a smaller set of tests to test only the feature
|
| 227 |
+
you're working on.
|
| 228 |
+
|
| 229 |
+
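For example, the standard `pytest` selection options work here; the module path and `-k` keyword below are placeholders, not specific TRL tests:

```bash
# Run a single test module
python -m pytest -sv tests/test_sft_trainer.py

# Run only the tests whose names match a keyword expression
python -m pytest -sv tests/test_sft_trainer.py -k "gradient_checkpointing"
```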
### Default values guidelines

1. **Use defaults when appropriate**:

   Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should.

2. **Prioritize proven defaults**:

   Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases.

3. **Ensure safety and predictability**:

   Defaults must be safe, expected, and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.

4. **Balance consistency and flexibility**:

   Aim for consistent defaults across similar functions or methods. However, consistency should not take precedence over points 2 or 3.

5. **Opt-in for new features**:

   Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt in to use these.
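As an illustration, here is a config sketch that follows these guidelines; the class and field names are hypothetical and not an actual TRL config:

```python
from dataclasses import dataclass


@dataclass
class MyTrainerConfig:
    # No default: the model varies too much by use case (guideline 1).
    model_name: str
    # Proven default taken from the original method's paper (guideline 2).
    learning_rate: float = 1e-5
    # Safe, predictable default that avoids surprising memory behavior (guideline 3).
    gradient_checkpointing: bool = False
    # A new or experimental feature stays opt-in (guideline 5).
    use_novel_loss: bool = False
```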
### Writing documentation

High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project.

To illustrate what good documentation looks like, here's an example of a well-documented function:

````python
def replicate_str(string: str, n: int, sep: str = " ") -> str:
    r"""
    Replicate a string `n` times with a separator.

    Args:
        string (`str`):
            String to replicate.
        n (`int`):
            Number of times to replicate the string.
        sep (`str`, *optional*, defaults to `" "`):
            Separator to use between each replication.

    Returns:
        `str`: The replicated string.

    Examples:
    ```python
    >>> replicate_str("hello", 3)
    "hello hello hello"
    >>> replicate_str("hello", 3, sep=", ")
    "hello, hello, hello"
    ```
    """
    return sep.join([string] * n)
````

* **Line Wrapping:** Apply a consistent line wrap at column 120 to improve readability.
* **Definite Articles:** Remove definite articles where possible to streamline language (e.g., change "The string to replicate" to "String to replicate").
* **Type Annotations:**
  * Always include type definitions, indicating if a parameter is optional and specifying the default value.
* **String Defaults:**
  * Ensure that default string values are wrapped in double quotes:

    ```txt
    defaults to `"foo"`
    ```

* **Dictionary Typing:**
  * Replace generic `dict` type hints with the more explicit `dict[str, Any]` to clarify expected key-value pairs.
* **Default Value Formatting:**
  * Consistently surround default values with backticks for improved formatting:

    ```txt
    defaults to `4`
    ```

* **Sub-sectioning:** When the number of arguments is large, consider breaking them into sub-sections for better readability.

  ```python
  def calculate_statistics(data: list[float], precision: int = 2, include_variance: bool = False) -> dict[str, float]:
      r"""
      Calculates basic statistics for a given dataset.

      Args:
          > Data inputs

          data (`list[float]`):
              A list of numerical values to analyze.

          > Configuration parameters

          precision (`int`, *optional*, defaults to `2`):
              Number of decimal places to round the results.
          include_variance (`bool`, *optional*, defaults to `False`):
              Whether to include the variance of the dataset in the results.

      Returns:
          `dict[str, float]`:
              A dictionary containing calculated statistics such as mean, median, and optionally variance.
      """
      ...
  ```

### Deprecation and backward compatibility

Our approach to deprecation and backward compatibility is flexible and based on the feature's usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.

When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:

* **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
* **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.

Example:

```python
warnings.warn(
    "The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. "
    "Please use the `Trainer.bar` class instead.",
    FutureWarning,
    stacklevel=2,
)
```

The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:

* **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.

* **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.

These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
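In practice, the deprecated path usually keeps working until the removal version, for example by mapping the old argument onto its replacement. A minimal sketch with hypothetical `old_arg`/`new_arg` names (the version number mirrors the example above):

```python
import warnings


def my_function(new_arg=None, old_arg=None):
    # Hypothetical names: `old_arg` is deprecated in favor of `new_arg`.
    if old_arg is not None:
        warnings.warn(
            "`old_arg` is deprecated and will be removed in version 0.14.0. "
            "Please use `new_arg` instead.",
            FutureWarning,
            stacklevel=2,
        )
        if new_arg is None:
            # Keep the deprecated path functional until the removal version.
            new_arg = old_arg
    return new_arg
```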
### Working with warnings

Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely.

#### Definitions

* **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
* **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.

#### Choosing the right message

* **Correct → No warning**:

  If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.

* **Correct but deserves attention → No warning, possibly a log message**:

  When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:

  ```python
  logger.info("This is an informational message about a rare but correct operation.")
  ```

* **Correct but very likely a mistake → Warning with option to disable**:

  In rare cases, you may want to issue a warning for a correct operation that's very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:

  ```python
  def my_function(foo, bar, _warn=True):
      if foo == bar:
          if _warn:
              logger.warning("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.")
          # Do something
  ```

* **Supported but not correct → Warning**:

  If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:

  ```python
  def my_function(foo, bar):
      if foo and bar:
          logger.warning("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. Please pass only one of these arguments.")
      # Do something
  ```

* **Not supported → Exception**:

  If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:

  ```python
  def my_function(foo, bar):
      if foo and bar:
          raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.")
  ```

By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.
ICL/RL/trl_source/LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2020-2026 The HuggingFace Team

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
ICL/RL/trl_source/MANIFEST.in
ADDED
@@ -0,0 +1,7 @@
include LICENSE
include CONTRIBUTING.md
include README.md
include trl/accelerate_configs/*.yaml
include trl/templates/*.md
recursive-exclude * __pycache__
prune tests
ICL/RL/trl_source/Makefile
ADDED
@@ -0,0 +1,19 @@
.PHONY: test precommit common_tests slow_tests tests_gpu test_experimental

check_dirs := examples tests trl

ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs

test:
	pytest -n auto -m "not slow and not low_priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests

precommit:
	python scripts/add_copyrights.py
	pre-commit run --all-files
	doc-builder style trl tests docs/source --max_len 119

slow_tests:
	pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)

test_experimental:
	pytest -n auto -s -v tests/experimental
ICL/RL/trl_source/README.md
ADDED
@@ -0,0 +1,207 @@
# TRL - Transformer Reinforcement Learning

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png" alt="TRL Banner">
</div>

<hr> <br>

<h3 align="center">
<p>A comprehensive library to post-train foundation models</p>
</h3>

<p align="center">
<a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
<a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website?label=documentation&url=https%3A%2F%2Fhuggingface.co%2Fdocs%2Ftrl%2Findex&down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
<a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
<a href="https://huggingface.co/trl-lib"><img alt="Hugging Face Hub" src="https://img.shields.io/badge/🤗%20Hub-trl--lib-yellow"></a>
</p>

## 🎉 What's New

**OpenEnv Integration:** TRL now supports **[OpenEnv](https://huggingface.co/blog/openenv)**, the open-source framework from Meta for defining, deploying, and interacting with environments in reinforcement learning and agentic workflows.

Explore how to seamlessly integrate TRL with OpenEnv in our [dedicated documentation](https://huggingface.co/docs/trl/openenv).

## Overview

TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled up across various hardware setups.

## Highlights

- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer) and more.

- **Efficient and scalable**:
  - Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like [DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [DeepSpeed](https://github.com/deepspeedai/DeepSpeed).
  - Full integration with [🤗 PEFT](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
  - Integrates [🦥 Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.

- **Command Line Interface (CLI)**: A simple interface lets you fine-tune models without needing to write code.

## Installation

### Python Package

Install the library using `pip`:

```bash
pip install trl
```

### From source

If you want to use the latest features before an official release, you can install TRL from source:

```bash
pip install git+https://github.com/huggingface/trl.git
```

### Repository

If you want to use the examples, you can clone the repository with the following command:

```bash
git clone https://github.com/huggingface/trl.git
```

## Quick Start

For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
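The snippets below run on a single GPU as-is; for multi-GPU setups they can be launched through 🤗 Accelerate. A minimal sketch, where the script name is a placeholder for a file containing one of these snippets:

```bash
# Launch the same training script across 8 processes (one per GPU)
accelerate launch --num_processes 8 my_training_script.py
```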
### `SFTTrainer`

Here is a basic example of how to use the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer):

```python
from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)
trainer.train()
```

### `GRPOTrainer`

[`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer) implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300), which is more memory-efficient than PPO and was used to train [DeepSeek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).

```python
from datasets import load_dataset
from trl import GRPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
)
trainer.train()
```

> [!NOTE]
> For reasoning models, use the `reasoning_accuracy_reward()` function for better results.
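Besides the built-in rewards in `trl.rewards`, `GRPOTrainer` also accepts plain Python callables as reward functions. A minimal sketch reusing `dataset` from the example above, assuming a standard-format (non-conversational) dataset where each completion is a string; the length-based toy reward is illustrative only:

```python
def reward_longer_is_better(completions, **kwargs):
    # Return one float per completion; here, longer completions score higher.
    return [float(len(completion)) for completion in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_longer_is_better,
    train_dataset=dataset,
)
```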
### `DPOTrainer`

[`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer) implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train [Llama 3](https://huggingface.co/papers/2407.21783) and many other models. Here is a basic example of how to use the `DPOTrainer`:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
```

### `RewardTrainer`

Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):

```python
from trl import RewardTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

trainer = RewardTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
)
trainer.train()
```

## Command Line Interface (CLI)

You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):

**SFT:**

```bash
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/Capybara \
    --output_dir Qwen2.5-0.5B-SFT
```

**DPO:**

```bash
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name argilla/Capybara-Preferences \
    --output_dir Qwen2.5-0.5B-DPO
```

Read more about the CLI in the [relevant documentation section](https://huggingface.co/docs/trl/clis) or use `--help` for more details.

## Development

If you want to contribute to `trl` or customize it to your needs, make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and do a dev install:

```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]
```

## Experimental

A minimal incubation area is available under `trl.experimental` for unstable / fast-evolving features. Anything there may change or be removed in any release without notice.

Example:

```python
from trl.experimental.new_trainer import NewTrainer
```

Read more in the [Experimental docs](https://huggingface.co/docs/trl/experimental_overview).

## Citation

```bibtex
@software{vonwerra2020trl,
  title = {{TRL: Transformers Reinforcement Learning}},
  author = {von Werra, Leandro and Belkada, Younes and Tunstall, Lewis and Beeching, Edward and Thrush, Tristan and Lambert, Nathan and Huang, Shengyi and Rasul, Kashif and Gallouédec, Quentin},
  license = {Apache-2.0},
  url = {https://github.com/huggingface/trl},
  year = {2020}
}
```

## License

This repository's source code is available under the [Apache-2.0 License](LICENSE).
ICL/RL/trl_source/RELEASE.md
ADDED
@@ -0,0 +1,167 @@
# Making a release

> [!NOTE]
> VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We need to follow this convention to be able to retrieve versioned scripts.

## Major/Minor Release

### 1. Ensure your local repository is up to date with the upstream repository

```bash
git checkout main
git pull origin main
```

> [!WARNING]
> Do not merge other pull requests into `main` until the release is done. This is to ensure that the release is stable and does not include any untested changes. Announce internally (#trl-internal) to other maintainers that you are doing a release and that they must not merge PRs until the release is done.

### 2. Create a release branch from main

```bash
git checkout -b release-v{major}.{minor}
```

### 3. Change the version in the following files

- `.github/workflows/tests_latest.yml`:

  ```diff
  - with: { ref: v{major}.{minor-1}-release }
  + with: { ref: v{major}.{minor}-release }
  ```

- `CITATION.cff`

  ```diff
  - version: '{major}.{minor-1}'
  + version: '{major}.{minor}'
  ```

- `VERSION`

  ```diff
  - {major}.{minor}.0.dev0
  + {major}.{minor}.0
  ```

### 4. Commit and push these changes

```shell
git add .github/workflows/tests_latest.yml CITATION.cff VERSION
git commit -m 'Release: {major}.{minor}'
git push origin release-v{major}.{minor}
```

### 5. Create a pull request

Open it from `release-v{major}.{minor}` to `main`, named `Release: v{major}.{minor}`, wait for the tests to pass, and request a review.

### 6. Once the pull request is approved, merge it into `main`

It will automatically publish the new version of the package on PyPI.

### 7. Add a tag in git to mark the release

```shell
git checkout main
git pull origin main
git tag -a v{major}.{minor}.0 -m 'Adds tag v{major}.{minor}.0 for PyPI'
git push origin v{major}.{minor}.0
```

### 8. Create a branch `v{major}.{minor}-release` for future patch releases

```shell
git checkout -b v{major}.{minor}-release
git push origin v{major}.{minor}-release
```

This ensures that future patch releases (`v{major}.{minor}.1`, `v{major}.{minor}.2`, etc.) can be made separately from `main`.

### 9. Create a GitHub Release

1. Go to the repo's [releases section](https://github.com/huggingface/trl/releases) on GitHub.
2. Click **Draft a new release**.
3. Select the `v{major}.{minor}.0` tag you just created in step 7.
4. Add a title (`v{major}.{minor}.0`) and a short description of what's new.
5. Click **Publish Release**.

### 10. Bump to dev version

1. Create a branch `bump-dev-version-{major}.{minor+1}` from `main` and check it out.

   ```shell
   git checkout -b bump-dev-version-{major}.{minor+1}
   ```

2. Change the version in the file `VERSION`:

   ```diff
   - {major}.{minor}.0
   + {major}.{minor+1}.0.dev0
   ```

3. Commit and push these changes:

   ```shell
   git add VERSION
   git commit -m '⬆️ Bump dev version'
   git push origin bump-dev-version-{major}.{minor+1}
   ```

4. Create a pull request from `bump-dev-version-{major}.{minor+1}` to `main`, named `⬆️ Bump dev version`, and request urgent review.

5. Once the pull request is approved, merge it into `main`.

6. The codebase is now ready for the next development cycle; inform the team in the #trl-internal channel.

## Making a patch release

### 1. Ensure your local repository is up to date with the upstream repository

```bash
git checkout v{major}.{minor}-release
git pull origin main
```

### 2. Cherry-pick the changes you want to include in the patch release

```bash
git cherry-pick <commit-hash-0>
git cherry-pick <commit-hash-1>
...
```

### 3. Change the version in the file `VERSION`

```diff
- {major}.{minor}.{patch-1}
+ {major}.{minor}.{patch}
```

### 4. Commit and push these changes

```shell
git add VERSION
git commit -m 'Release: {major}.{minor}.{patch}'
git push origin v{major}.{minor}-release
```

### 5. Wait for the CI to pass

The CI will automatically publish the new version of the package on PyPI.

### 6. Add a tag in git to mark the release

```shell
git tag -a v{major}.{minor}.{patch} -m 'Adds tag v{major}.{minor}.{patch} for PyPI'
git push origin v{major}.{minor}.{patch}
```

### 7. Create a GitHub Release

1. Go to the repo's [releases section](https://github.com/huggingface/trl/releases) on GitHub.
2. Click **Draft a new release**.
3. Select the `v{major}.{minor}.{patch}` tag you just created in step 6.
4. Add a title (`v{major}.{minor}.{patch}`) and a short description of what's new.
5. Click **Publish Release**.
ICL/RL/trl_source/VERSION
ADDED
@@ -0,0 +1 @@
0.29.0.dev0
ICL/RL/trl_source/pyproject.toml
ADDED
@@ -0,0 +1,194 @@
[build-system]
requires = ["setuptools >= 77.0.3"]
build-backend = "setuptools.build_meta"

[project]
name = "trl"
description = "Train transformer language models with reinforcement learning."
authors = [
    { name = "Leandro von Werra", email = "leandro.vonwerra@gmail.com" }
]
readme = { file = "README.md", content-type = "text/markdown" }
license = "Apache-2.0"
license-files = ["LICENSE"]
keywords = [
    "transformers", "huggingface", "language modeling", "post-training", "rlhf", "sft", "dpo", "grpo"
]
classifiers = [
    "Development Status :: 2 - Pre-Alpha",
    "Intended Audience :: Developers",
    "Intended Audience :: Science/Research",
    "Natural Language :: English",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13"
]
requires-python = ">=3.10"
dependencies = [
    "accelerate>=1.4.0",
    "datasets>=3.0.0",
    "packaging>20.0",
    "transformers>=4.56.2",
]
dynamic = ["version"]

[project.urls]
Homepage = "https://github.com/huggingface/trl"

[project.scripts]
trl = "trl.cli:main"

[project.optional-dependencies]
bco = [
    "scikit-learn",
    "joblib"
]
deepspeed = [
    "deepspeed>=0.14.4",
    "transformers!=5.1.0", # see transformers#43780
]
judges = [
    "openai>=1.23.2",
    "llm-blender>=0.0.2",
    "transformers<5.0.0", # see #4918
]
kernels = [
    "kernels"
]
liger = [
    "liger-kernel>=0.6.4"
]
peft = [
    "peft>=0.8.0"
]
quality = [
    "pre-commit",
    "hf-doc-builder"
]
quantization = [
    "bitsandbytes"
]
scikit = [
    "scikit-learn"
]
test = [
    "pytest-cov",
    "pytest-datadir>=1.7.0", # lazy datadirs
    "pytest-rerunfailures==15.1",
    "pytest-xdist",
    "pytest"
]
vllm = [
    "vllm>=0.10.2,<0.13.0",
    "fastapi",
    "pydantic",
    "requests",
    "uvicorn"
]
vlm = [
    "Pillow",
    "torchvision",
    "num2words==0.5.14"
]
math_verify = [
    "math-verify>=0.5.2",
]
dev = [
    # bco
    "scikit-learn",
    "joblib",
    # deepspeed
    "deepspeed>=0.14.4",
    # judges
    "openai>=1.23.2",
    "llm-blender>=0.0.2",
    # kernels
    "kernels",
    # liger
    "liger-kernel>=0.6.4",
    # peft
    "peft>=0.8.0",
    # quality
    "pre-commit",
    "hf-doc-builder",
    # quantization
    "bitsandbytes",
    # scikit: included in bco
    # test
    "pytest-cov",
    "pytest-datadir>=1.7.0", # lazy datadirs
    "pytest-rerunfailures==15.1",
    "pytest-xdist",
    "pytest",
    # vllm: not included in dev by default due to CUDA error; see GH-4228
    # vlm
    "Pillow",
    "torchvision",
    "num2words==0.5.14",
    # for response parsing (required for training with tools)
    "jmespath",
]

[tool.setuptools]
package-dir = {"trl" = "trl"}

[tool.setuptools.dynamic]
version = { file = "VERSION" }

[tool.coverage.run]
branch = true

[tool.ruff]
target-version = "py310"
line-length = 119
src = ["trl"]

[tool.ruff.lint]
ignore = [
    "B028", # warning without explicit stacklevel
    "C408", # dict() calls (stylistic)
    "C901", # function complexity
    "E501",
]
extend-select = ["E", "F", "I", "W", "UP", "B", "T", "C"]

[tool.ruff.lint.per-file-ignores]
# Allow prints in auxiliary scripts
"examples/**.py" = ["T201"]
"scripts/**.py" = ["T201"]
# Ignore import violations in all `__init__.py` files.
"__init__.py" = ["F401"]

[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["trl"]

[tool.pytest.ini_options]
markers = [
    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
    "low_priority: marks tests as low priority (deselect with '-m \"not low_priority\"')"
]
norecursedirs = [
    "tests/experimental",
]
filterwarnings = [
    # SWIG deprecations from SWIG-generated C/C++ extensions: sentencepiece
    # Upstream issue: https://github.com/google/sentencepiece/issues/1150
    "ignore:builtin type SwigPyPacked has no __module__ attribute:DeprecationWarning",
    "ignore:builtin type SwigPyObject has no __module__ attribute:DeprecationWarning",
    "ignore:builtin type swigvarlink has no __module__ attribute:DeprecationWarning",

    # PyTorch JIT deprecations (upstream, not actionable in TRL)
    # Upstream issue: https://github.com/deepspeedai/DeepSpeed/issues/7835
    "ignore:`torch.jit.script_method` is deprecated:DeprecationWarning",
    "ignore:`torch.jit.script` is deprecated:DeprecationWarning",

    # PyTorch DataLoader pin_memory device argument deprecations
    # Triggered internally by torch.utils.data, not by our code
    # Upstream issue: https://github.com/pytorch/pytorch/issues/174546
    "ignore:The argument 'device' of Tensor.pin_memory:DeprecationWarning",
    "ignore:The argument 'device' of Tensor.is_pinned:DeprecationWarning",
]
ICL/RL/trl_source/requirements.txt
ADDED
@@ -0,0 +1,3 @@
accelerate>=1.4.0
datasets>=3.0.0
transformers>=4.56.2