Lekr0 commited on
Commit
e9585fc
·
verified ·
1 Parent(s): 741f7c3

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ICL/RL/trl_source/.github/PULL_REQUEST_TEMPLATE.md +31 -0
  2. ICL/RL/trl_source/assets/logo-dark.png +0 -0
  3. ICL/RL/trl_source/examples/README.md +3 -0
  4. ICL/RL/trl_source/examples/accelerate_configs/alst_ulysses_4gpu.yaml +45 -0
  5. ICL/RL/trl_source/examples/accelerate_configs/context_parallel_2gpu.yaml +30 -0
  6. ICL/RL/trl_source/examples/accelerate_configs/deepspeed_zero1.yaml +20 -0
  7. ICL/RL/trl_source/examples/accelerate_configs/deepspeed_zero2.yaml +21 -0
  8. ICL/RL/trl_source/examples/accelerate_configs/deepspeed_zero3.yaml +22 -0
  9. ICL/RL/trl_source/examples/accelerate_configs/fsdp1.yaml +28 -0
  10. ICL/RL/trl_source/examples/accelerate_configs/fsdp2.yaml +25 -0
  11. ICL/RL/trl_source/examples/accelerate_configs/multi_gpu.yaml +16 -0
  12. ICL/RL/trl_source/examples/accelerate_configs/single_gpu.yaml +16 -0
  13. ICL/RL/trl_source/examples/cli_configs/example_config.yaml +18 -0
  14. ICL/RL/trl_source/examples/datasets/deepmath_103k.py +98 -0
  15. ICL/RL/trl_source/examples/datasets/hh-rlhf-helpful-base.py +132 -0
  16. ICL/RL/trl_source/examples/datasets/llava_instruct_mix.py +118 -0
  17. ICL/RL/trl_source/examples/datasets/lm-human-preferences-descriptiveness.py +119 -0
  18. ICL/RL/trl_source/examples/datasets/lm-human-preferences-sentiment.py +112 -0
  19. ICL/RL/trl_source/examples/datasets/math_shepherd.py +169 -0
  20. ICL/RL/trl_source/examples/datasets/prm800k.py +156 -0
  21. ICL/RL/trl_source/examples/datasets/rlaif-v.py +112 -0
  22. ICL/RL/trl_source/examples/datasets/tldr.py +104 -0
  23. ICL/RL/trl_source/examples/datasets/tldr_preference.py +110 -0
  24. ICL/RL/trl_source/examples/datasets/ultrafeedback-prompt.py +102 -0
  25. ICL/RL/trl_source/examples/datasets/ultrafeedback.py +144 -0
  26. ICL/RL/trl_source/examples/notebooks/README.md +17 -0
  27. ICL/RL/trl_source/examples/notebooks/grpo_agent.ipynb +706 -0
  28. ICL/RL/trl_source/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb +1914 -0
  29. ICL/RL/trl_source/examples/notebooks/grpo_ministral3_vl.ipynb +740 -0
  30. ICL/RL/trl_source/examples/notebooks/grpo_qwen3_vl.ipynb +693 -0
  31. ICL/RL/trl_source/examples/notebooks/grpo_rnj_1_instruct.ipynb +622 -0
  32. ICL/RL/trl_source/examples/notebooks/grpo_trl_lora_qlora.ipynb +1638 -0
  33. ICL/RL/trl_source/examples/notebooks/openenv_sudoku_grpo.ipynb +0 -0
  34. ICL/RL/trl_source/examples/notebooks/openenv_wordle_grpo.ipynb +0 -0
  35. ICL/RL/trl_source/examples/notebooks/sft_ministral3_vl.ipynb +0 -0
  36. ICL/RL/trl_source/examples/notebooks/sft_qwen_vl.ipynb +0 -0
  37. ICL/RL/trl_source/examples/notebooks/sft_trl_lora_qlora.ipynb +1140 -0
  38. ICL/RL/trl_source/examples/scripts/bco.py +173 -0
  39. ICL/RL/trl_source/examples/scripts/cpo.py +112 -0
  40. ICL/RL/trl_source/examples/scripts/dpo.py +17 -0
  41. ICL/RL/trl_source/examples/scripts/dpo_vlm.py +151 -0
  42. ICL/RL/trl_source/examples/scripts/gkd.py +149 -0
  43. ICL/RL/trl_source/examples/scripts/grpo_agent.py +326 -0
  44. ICL/RL/trl_source/examples/scripts/grpo_vlm.py +164 -0
  45. ICL/RL/trl_source/examples/scripts/gspo.py +137 -0
  46. ICL/RL/trl_source/examples/scripts/gspo_vlm.py +153 -0
  47. ICL/RL/trl_source/examples/scripts/kto.py +112 -0
  48. ICL/RL/trl_source/examples/scripts/mpo_vlm.py +142 -0
  49. ICL/RL/trl_source/examples/scripts/nash_md.py +153 -0
  50. ICL/RL/trl_source/examples/scripts/nemo_gym/README.md +5 -0
ICL/RL/trl_source/.github/PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # What does this PR do?
2
+
3
+ <!--
4
+ Congratulations! You've made it this far! You're not quite done yet though.
5
+
6
+ Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.
7
+
8
+ Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.
9
+
10
+ Once you're done, someone will review your PR shortly. They may suggest changes to make the code even better.
11
+ -->
12
+
13
+ <!-- Remove if not applicable -->
14
+
15
+ Fixes # (issue)
16
+
17
+
18
+ ## Before submitting
19
+ - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
20
+ - [ ] Did you read the [contributor guideline](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#create-a-pull-request),
21
+ Pull Request section?
22
+ - [ ] Was this discussed/approved via a GitHub issue? Please add a link
23
+ to it if that's the case.
24
+ - [ ] Did you make sure to update the documentation with your changes?
25
+ - [ ] Did you write any new necessary tests?
26
+
27
+
28
+ ## Who can review?
29
+
30
+ Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
31
+ members/contributors who may be interested in your PR.
ICL/RL/trl_source/assets/logo-dark.png ADDED
ICL/RL/trl_source/examples/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Examples
2
+
3
+ Please check out https://huggingface.co/docs/trl/example_overview for documentation on our examples.
ICL/RL/trl_source/examples/accelerate_configs/alst_ulysses_4gpu.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ALST/Ulysses Sequence Parallelism with 2D Parallelism (DP + SP) for 4 GPUs
2
+ #
3
+ # This configuration enables 2D parallelism:
4
+ # - Sequence Parallelism (sp_size=2): Sequences split across 2 GPUs using ALST/Ulysses
5
+ # - Data Parallelism (dp_shard_size=2): Model/optimizer sharded across 2 GPUs
6
+ # - Total: 4 GPUs (2 × 2)
7
+ #
8
+ # Set parallelism_config in your training script:
9
+ # parallelism_config = ParallelismConfig(
10
+ # sp_backend="deepspeed",
11
+ # sp_size=2,
12
+ # dp_shard_size=2, # Calculated as: num_gpus // sp_size
13
+ # sp_handler=DeepSpeedSequenceParallelConfig(...)
14
+ # )
15
+
16
+ compute_environment: LOCAL_MACHINE
17
+ debug: false
18
+ deepspeed_config:
19
+ zero_stage: 3
20
+ seq_parallel_communication_data_type: bf16
21
+ offload_optimizer_device: none
22
+ offload_param_device: none
23
+ zero3_init_flag: true
24
+ zero3_save_16bit_model: true
25
+ distributed_type: DEEPSPEED
26
+ downcast_bf16: 'no'
27
+ machine_rank: 0
28
+ main_training_function: main
29
+ mixed_precision: bf16
30
+ num_machines: 1
31
+ num_processes: 4 # Total number of GPUs
32
+ rdzv_backend: static
33
+ same_network: true
34
+ tpu_env: []
35
+ tpu_use_cluster: false
36
+ tpu_use_sudo: false
37
+ use_cpu: false
38
+ parallelism_config:
39
+ parallelism_config_dp_replicate_size: 1
40
+ parallelism_config_dp_shard_size: 2 # Enables 2D parallelism with SP
41
+ parallelism_config_tp_size: 1
42
+ parallelism_config_sp_size: 2 # Sequence parallel size
43
+ parallelism_config_sp_backend: deepspeed
44
+ parallelism_config_sp_seq_length_is_variable: true
45
+ parallelism_config_sp_attn_implementation: flash_attention_2
ICL/RL/trl_source/examples/accelerate_configs/context_parallel_2gpu.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Context Parallelism with FSDP for 2 GPUs
2
+ compute_environment: LOCAL_MACHINE
3
+ debug: false
4
+ distributed_type: FSDP
5
+ downcast_bf16: 'no'
6
+ enable_cpu_affinity: false
7
+ fsdp_config:
8
+ fsdp_activation_checkpointing: true # Enable activation checkpointing for memory efficiency
9
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
10
+ fsdp_cpu_ram_efficient_loading: true
11
+ fsdp_offload_params: false
12
+ fsdp_reshard_after_forward: true
13
+ fsdp_state_dict_type: FULL_STATE_DICT
14
+ fsdp_version: 2
15
+ machine_rank: 0
16
+ main_training_function: main
17
+ mixed_precision: bf16
18
+ num_machines: 1
19
+ num_processes: 2 # Number of GPUs
20
+ rdzv_backend: static
21
+ same_network: true
22
+ tpu_env: []
23
+ tpu_use_cluster: false
24
+ tpu_use_sudo: false
25
+ use_cpu: false
26
+ parallelism_config:
27
+ parallelism_config_dp_replicate_size: 1
28
+ parallelism_config_dp_shard_size: 1
29
+ parallelism_config_tp_size: 1
30
+ parallelism_config_cp_size: 2 # Context parallel size
ICL/RL/trl_source/examples/accelerate_configs/deepspeed_zero1.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ deepspeed_multinode_launcher: standard
5
+ gradient_accumulation_steps: 1
6
+ zero3_init_flag: false
7
+ zero_stage: 1
8
+ distributed_type: DEEPSPEED
9
+ downcast_bf16: 'no'
10
+ machine_rank: 0
11
+ main_training_function: main
12
+ mixed_precision: 'bf16'
13
+ num_machines: 1
14
+ num_processes: 8
15
+ rdzv_backend: static
16
+ same_network: true
17
+ tpu_env: []
18
+ tpu_use_cluster: false
19
+ tpu_use_sudo: false
20
+ use_cpu: false
ICL/RL/trl_source/examples/accelerate_configs/deepspeed_zero2.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ deepspeed_multinode_launcher: standard
5
+ offload_optimizer_device: none
6
+ offload_param_device: none
7
+ zero3_init_flag: false
8
+ zero_stage: 2
9
+ distributed_type: DEEPSPEED
10
+ downcast_bf16: 'no'
11
+ machine_rank: 0
12
+ main_training_function: main
13
+ mixed_precision: 'bf16'
14
+ num_machines: 1
15
+ num_processes: 8
16
+ rdzv_backend: static
17
+ same_network: true
18
+ tpu_env: []
19
+ tpu_use_cluster: false
20
+ tpu_use_sudo: false
21
+ use_cpu: false
ICL/RL/trl_source/examples/accelerate_configs/deepspeed_zero3.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ deepspeed_multinode_launcher: standard
5
+ offload_optimizer_device: none
6
+ offload_param_device: none
7
+ zero3_init_flag: true
8
+ zero3_save_16bit_model: true
9
+ zero_stage: 3
10
+ distributed_type: DEEPSPEED
11
+ downcast_bf16: 'no'
12
+ machine_rank: 0
13
+ main_training_function: main
14
+ mixed_precision: bf16
15
+ num_machines: 1
16
+ num_processes: 8
17
+ rdzv_backend: static
18
+ same_network: true
19
+ tpu_env: []
20
+ tpu_use_cluster: false
21
+ tpu_use_sudo: false
22
+ use_cpu: false
ICL/RL/trl_source/examples/accelerate_configs/fsdp1.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: FSDP
4
+ downcast_bf16: 'no'
5
+ enable_cpu_affinity: false
6
+ fsdp_config:
7
+ fsdp_activation_checkpointing: false
8
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
9
+ fsdp_backward_prefetch: BACKWARD_PRE
10
+ fsdp_cpu_ram_efficient_loading: true
11
+ fsdp_forward_prefetch: true
12
+ fsdp_offload_params: false
13
+ fsdp_reshard_after_forward: FULL_SHARD
14
+ fsdp_state_dict_type: FULL_STATE_DICT
15
+ fsdp_sync_module_states: true
16
+ fsdp_use_orig_params: true
17
+ fsdp_version: 1
18
+ machine_rank: 0
19
+ main_training_function: main
20
+ mixed_precision: bf16
21
+ num_machines: 1
22
+ num_processes: 8
23
+ rdzv_backend: static
24
+ same_network: true
25
+ tpu_env: []
26
+ tpu_use_cluster: false
27
+ tpu_use_sudo: false
28
+ use_cpu: false
ICL/RL/trl_source/examples/accelerate_configs/fsdp2.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Requires accelerate 1.7.0 or higher
2
+ compute_environment: LOCAL_MACHINE
3
+ debug: false
4
+ distributed_type: FSDP
5
+ downcast_bf16: 'no'
6
+ enable_cpu_affinity: false
7
+ fsdp_config:
8
+ fsdp_activation_checkpointing: false
9
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
10
+ fsdp_cpu_ram_efficient_loading: true
11
+ fsdp_offload_params: false
12
+ fsdp_reshard_after_forward: true
13
+ fsdp_state_dict_type: FULL_STATE_DICT
14
+ fsdp_version: 2
15
+ machine_rank: 0
16
+ main_training_function: main
17
+ mixed_precision: bf16
18
+ num_machines: 1
19
+ num_processes: 8
20
+ rdzv_backend: static
21
+ same_network: true
22
+ tpu_env: []
23
+ tpu_use_cluster: false
24
+ tpu_use_sudo: false
25
+ use_cpu: false
ICL/RL/trl_source/examples/accelerate_configs/multi_gpu.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ gpu_ids: all
6
+ machine_rank: 0
7
+ main_training_function: main
8
+ mixed_precision: 'bf16'
9
+ num_machines: 1
10
+ num_processes: 8
11
+ rdzv_backend: static
12
+ same_network: true
13
+ tpu_env: []
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
ICL/RL/trl_source/examples/accelerate_configs/single_gpu.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: "NO"
4
+ downcast_bf16: 'no'
5
+ gpu_ids: all
6
+ machine_rank: 0
7
+ main_training_function: main
8
+ mixed_precision: 'bf16'
9
+ num_machines: 1
10
+ num_processes: 1
11
+ rdzv_backend: static
12
+ same_network: true
13
+ tpu_env: []
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
ICL/RL/trl_source/examples/cli_configs/example_config.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is an example configuration file of TRL CLI, you can use it for
2
+ # SFT like that: `trl sft --config config.yaml --output_dir test-sft`
3
+ # The YAML file supports environment variables by adding an `env` field
4
+ # as below
5
+
6
+ # env:
7
+ # CUDA_VISIBLE_DEVICES: 0
8
+
9
+ model_name_or_path:
10
+ Qwen/Qwen2.5-0.5B
11
+ dataset_name:
12
+ stanfordnlp/imdb
13
+ report_to:
14
+ none
15
+ learning_rate:
16
+ 0.0001
17
+ lr_scheduler_type:
18
+ cosine
ICL/RL/trl_source/examples/datasets/deepmath_103k.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+
17
+ from datasets import load_dataset
18
+ from huggingface_hub import ModelCard
19
+ from transformers import HfArgumentParser
20
+
21
+
22
+ @dataclass
23
+ class ScriptArguments:
24
+ r"""
25
+ Arguments for the script.
26
+
27
+ Args:
28
+ push_to_hub (`bool`, *optional*, defaults to `False`):
29
+ Whether to push the dataset to the Hugging Face Hub.
30
+ repo_id (`str`, *optional*, defaults to `"trl-lib/DeepMath-103K"`):
31
+ Hugging Face repository ID to push the dataset to.
32
+ dataset_num_proc (`int`, *optional*):
33
+ Number of workers to use for dataset processing.
34
+ """
35
+
36
+ push_to_hub: bool = field(
37
+ default=False,
38
+ metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
39
+ )
40
+ repo_id: str = field(
41
+ default="trl-lib/DeepMath-103K",
42
+ metadata={"help": "Hugging Face repository ID to push the dataset to."},
43
+ )
44
+ dataset_num_proc: int | None = field(
45
+ default=None,
46
+ metadata={"help": "Number of workers to use for dataset processing."},
47
+ )
48
+
49
+
50
+ def process_example(example):
51
+ solution = example["final_answer"]
52
+ if solution not in ["True", "False", "Yes", "No"]:
53
+ solution = f"${solution}$"
54
+ prompt = [{"role": "user", "content": example["question"]}]
55
+ return {"prompt": prompt, "solution": solution}
56
+
57
+
58
+ model_card = ModelCard("""
59
+ ---
60
+ tags: [trl]
61
+ ---
62
+
63
+ # DeepMath-103K Dataset
64
+
65
+ ## Summary
66
+
67
+ [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) is meticulously curated to push the boundaries of mathematical reasoning in language models.
68
+
69
+ ## Data Structure
70
+
71
+ - **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational)
72
+ - **Type**: [Prompt-only](https://huggingface.co/docs/trl/main/dataset_formats#prompt-only)
73
+
74
+ Column:
75
+ - `"prompt"`: The input question.
76
+ - `"solution"`: The solution to the math problem.
77
+
78
+ ## Generation script
79
+
80
+ The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/deepmath_103k.py).
81
+ """)
82
+
83
+ if __name__ == "__main__":
84
+ parser = HfArgumentParser(ScriptArguments)
85
+ script_args = parser.parse_args_into_dataclasses()[0]
86
+
87
+ dataset = load_dataset("zwhe99/DeepMath-103K", split="train")
88
+
89
+ dataset = dataset.map(
90
+ process_example,
91
+ remove_columns=dataset.column_names,
92
+ num_proc=script_args.dataset_num_proc,
93
+ )
94
+ dataset = dataset.train_test_split(test_size=0.05, seed=42)
95
+
96
+ if script_args.push_to_hub:
97
+ dataset.push_to_hub(script_args.repo_id)
98
+ model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
ICL/RL/trl_source/examples/datasets/hh-rlhf-helpful-base.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from dataclasses import dataclass, field
17
+
18
+ from datasets import load_dataset
19
+ from huggingface_hub import ModelCard
20
+ from transformers import HfArgumentParser
21
+
22
+
23
+ @dataclass
24
+ class ScriptArguments:
25
+ r"""
26
+ Arguments for the script.
27
+
28
+ Args:
29
+ push_to_hub (`bool`, *optional*, defaults to `False`):
30
+ Whether to push the dataset to the Hugging Face Hub.
31
+ repo_id (`str`, *optional*, defaults to `"trl-lib/hh-rlhf-helpful-base"`):
32
+ Hugging Face repository ID to push the dataset to.
33
+ dataset_num_proc (`int`, *optional*):
34
+ Number of workers to use for dataset processing.
35
+ """
36
+
37
+ push_to_hub: bool = field(
38
+ default=False,
39
+ metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
40
+ )
41
+ repo_id: str = field(
42
+ default="trl-lib/hh-rlhf-helpful-base", metadata={"help": "Hugging Face repository ID to push the dataset to."}
43
+ )
44
+ dataset_num_proc: int | None = field(
45
+ default=None, metadata={"help": "Number of workers to use for dataset processing."}
46
+ )
47
+
48
+
49
+ def common_start(str1: str, str2: str) -> str:
50
+ # Zip the two strings and iterate over them together
51
+ common_chars = []
52
+ for c1, c2 in zip(str1, str2, strict=True):
53
+ if c1 == c2:
54
+ common_chars.append(c1)
55
+ else:
56
+ break
57
+ # Join the common characters and return as a string
58
+ return "".join(common_chars)
59
+
60
+
61
+ def extract_dialogue(example: str) -> list[dict[str, str]]:
62
+ # Extract the prompt, which corresponds to the common start of the chosen and rejected dialogues
63
+ prompt_text = common_start(example["chosen"], example["rejected"])
64
+
65
+ # The chosen and rejected may share a common start, so we need to remove the common part
66
+ if not prompt_text.endswith("\n\nAssistant: "):
67
+ prompt_text = prompt_text[: prompt_text.rfind("\n\nAssistant: ")] + "\n\nAssistant: "
68
+
69
+ # Extract the chosen and rejected lines
70
+ chosen_line = example["chosen"][len(prompt_text) :]
71
+ rejected_line = example["rejected"][len(prompt_text) :]
72
+
73
+ # Remove the generation prompt ("\n\nAssistant: ") from the prompt
74
+ prompt_text = prompt_text[: -len("\n\nAssistant: ")]
75
+
76
+ # Split the string at every occurrence of "Human: " or "Assistant: "
77
+ prompt_lines = re.split(r"(\n\nAssistant: |\n\nHuman: )", prompt_text)
78
+
79
+ # Remove the first element as it's empty
80
+ prompt_lines = prompt_lines[1:]
81
+
82
+ prompt = []
83
+ for idx in range(0, len(prompt_lines), 2):
84
+ role = "user" if prompt_lines[idx] == "\n\nHuman: " else "assistant"
85
+ content = prompt_lines[idx + 1]
86
+ prompt.append({"role": role, "content": content})
87
+
88
+ # Remove the prompt from the chosen and rejected dialogues
89
+ chosen = [{"role": "assistant", "content": chosen_line}]
90
+ rejected = [{"role": "assistant", "content": rejected_line}]
91
+
92
+ return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
93
+
94
+
95
+ model_card = ModelCard("""
96
+ ---
97
+ tags: [trl]
98
+ ---
99
+
100
+ # HH-RLHF-Helpful-Base Dataset
101
+
102
+ ## Summary
103
+
104
+ The HH-RLHF-Helpful-Base dataset is a processed version of [Anthropic's HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset, specifically curated to train models using the [TRL library](https://github.com/huggingface/trl) for preference learning and alignment tasks. It contains pairs of text samples, each labeled as either "chosen" or "rejected," based on human preferences regarding the helpfulness of the responses. This dataset enables models to learn human preferences in generating helpful responses, enhancing their ability to assist users effectively.
105
+
106
+ ## Data Structure
107
+
108
+ - **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational)
109
+ - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference)
110
+
111
+ Columns:
112
+ - `"prompt"`: The user query.
113
+ - `"chosen"`: A response deemed helpful by human evaluators.
114
+ - `"rejected"`: A response considered less helpful or unhelpful.
115
+
116
+ This structure allows models to learn to prefer the _chosen_ response over the _rejected_ one, thereby aligning with human preferences in helpfulness.
117
+
118
+ ## Generation script
119
+
120
+ The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/hh-rlhf-helpful-base.py).
121
+ """)
122
+
123
+ if __name__ == "__main__":
124
+ parser = HfArgumentParser(ScriptArguments)
125
+ script_args = parser.parse_args_into_dataclasses()[0]
126
+
127
+ dataset = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base")
128
+ dataset = dataset.map(extract_dialogue, num_proc=script_args.dataset_num_proc)
129
+
130
+ if script_args.push_to_hub:
131
+ dataset.push_to_hub(script_args.repo_id)
132
+ model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
ICL/RL/trl_source/examples/datasets/llava_instruct_mix.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import ast
16
+ from dataclasses import dataclass, field
17
+
18
+ from datasets import load_dataset
19
+ from huggingface_hub import ModelCard
20
+ from transformers import HfArgumentParser
21
+
22
+
23
+ @dataclass
24
+ class ScriptArguments:
25
+ r"""
26
+ Arguments for the script.
27
+
28
+ Args:
29
+ push_to_hub (`bool`, *optional*, defaults to `False`):
30
+ Whether to push the dataset to the Hugging Face Hub.
31
+ repo_id (`str`, *optional*, defaults to `"trl-lib/llava-instruct-mix"`):
32
+ Hugging Face repository ID to push the dataset to.
33
+ dataset_num_proc (`int`, *optional*):
34
+ Number of workers to use for dataset processing.
35
+ """
36
+
37
+ push_to_hub: bool = field(
38
+ default=False,
39
+ metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
40
+ )
41
+ repo_id: str = field(
42
+ default="trl-lib/llava-instruct-mix",
43
+ metadata={"help": "Hugging Face repository ID to push the dataset to."},
44
+ )
45
+ dataset_num_proc: int | None = field(
46
+ default=None,
47
+ metadata={"help": "Number of workers to use for dataset processing."},
48
+ )
49
+
50
+
51
+ def process_example(example):
52
+ messages = []
53
+ for message in ast.literal_eval(example["conversations"]):
54
+ content = message["value"]
55
+ content = content.replace("<image>", "").strip()
56
+ role = "user" if message["from"] == "human" else "assistant"
57
+ messages.append({"role": role, "content": content})
58
+ return {"messages": messages, "images": [example["image"]]}
59
+
60
+
61
+ def filter_long_examples(example):
62
+ total_length = sum(len(msg["content"]) for msg in example["messages"])
63
+ return total_length <= 1000
64
+
65
+
66
+ def split_prompt_completion(example):
67
+ """
68
+ Splits the messages into a prompt and a completion. The last message is considered the completion.
69
+ """
70
+ assert len(example["messages"]) > 1
71
+ example["prompt"] = example["messages"][:-1]
72
+ example["completion"] = example["messages"][-1:]
73
+ return example
74
+
75
+
76
+ model_card = ModelCard("""
77
+ ---
78
+ tags: [trl]
79
+ ---
80
+
81
+ # LLaVA Instruct Mix
82
+
83
+ ## Summary
84
+
85
+ The LLaVA Instruct Mix dataset is a processed version of [LLaVA Instruct Mix](https://huggingface.co/datasets/theblackcat102/llava-instruct-mix).
86
+
87
+ ## Data Structure
88
+
89
+ - **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational)
90
+ - **Type**: [Language-modeling](https://huggingface.co/docs/trl/main/dataset_formats#language-modeling)
91
+
92
+ Columns:
93
+ - `"images"`: The image associated with the text.
94
+ - `"prompt"`: A list of messages that form the context for the conversation.
95
+ - `"completion"`: The last message in the conversation, which is the model's response.
96
+
97
+ This structure allows models to learn from the context of the conversation, enhancing their understanding of how to generate descriptive text based on visual inputs.
98
+
99
+ ## Generation script
100
+
101
+ The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/llava_instruct_mix.py).
102
+ """)
103
+
104
+ if __name__ == "__main__":
105
+ parser = HfArgumentParser(ScriptArguments)
106
+ script_args = parser.parse_args_into_dataclasses()[0]
107
+
108
+ dataset = load_dataset("theblackcat102/llava-instruct-mix", split="train", num_proc=script_args.dataset_num_proc)
109
+
110
+ dataset = dataset.map(
111
+ process_example, remove_columns=["conversations", "image"], num_proc=script_args.dataset_num_proc
112
+ )
113
+ dataset = dataset.filter(filter_long_examples, num_proc=script_args.dataset_num_proc)
114
+ dataset = dataset.map(split_prompt_completion, remove_columns=["messages"], num_proc=script_args.dataset_num_proc)
115
+
116
+ if script_args.push_to_hub:
117
+ dataset.push_to_hub(script_args.repo_id, num_proc=script_args.dataset_num_proc)
118
+ model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
ICL/RL/trl_source/examples/datasets/lm-human-preferences-descriptiveness.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+
17
+ from datasets import load_dataset
18
+ from huggingface_hub import ModelCard
19
+ from transformers import AutoTokenizer, HfArgumentParser
20
+
21
+
22
+ @dataclass
23
+ class ScriptArguments:
24
+ r"""
25
+ Arguments for the script.
26
+
27
+ Args:
28
+ push_to_hub (`bool`, *optional*, defaults to `False`):
29
+ Whether to push the dataset to the Hugging Face Hub.
30
+ repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-descriptiveness"`):
31
+ Hugging Face repository ID to push the dataset to.
32
+ dataset_num_proc (`int`, *optional*):
33
+ Number of workers to use for dataset processing.
34
+ """
35
+
36
+ push_to_hub: bool = field(
37
+ default=False,
38
+ metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
39
+ )
40
+ repo_id: str = field(
41
+ default="trl-lib/lm-human-preferences-descriptiveness",
42
+ metadata={"help": "Hugging Face repository ID to push the dataset to."},
43
+ )
44
+ dataset_num_proc: int | None = field(
45
+ default=None,
46
+ metadata={"help": "Number of workers to use for dataset processing."},
47
+ )
48
+
49
+
50
+ # Edge cases handling: remove the cases where all samples are the same
51
+ def samples_not_all_same(example):
52
+ return not all(example["sample0"] == example[f"sample{j}"] for j in range(1, 4))
53
+
54
+
55
+ def to_prompt_completion(example, tokenizer):
56
+ prompt = tokenizer.decode(example["query"]).strip()
57
+ best_idx = example["best"]
58
+ chosen = tokenizer.decode(example[f"sample{best_idx}"])
59
+ for rejected_idx in range(4): # take the first rejected sample that is different from the chosen one
60
+ rejected = tokenizer.decode(example[f"sample{rejected_idx}"])
61
+ if chosen != rejected:
62
+ break
63
+ assert chosen != rejected
64
+ return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
65
+
66
+
67
+ model_card = ModelCard("""
68
+ ---
69
+ tags: [trl]
70
+ ---
71
+
72
+ # LM-Human-Preferences-Descriptiveness Dataset
73
+
74
+ ## Summary
75
+
76
+ The LM-Human-Preferences-Descriptiveness dataset is a processed subset of [OpenAI's LM-Human-Preferences](https://github.com/openai/lm-human-preferences), focusing specifically on enhancing the descriptiveness of generated text. It contains pairs of text samples, each labeled as either "chosen" or "rejected," based on human preferences regarding the level of detail and vividness in the descriptions. This dataset enables models to learn human preferences in descriptive language, improving their ability to generate rich and engaging narratives.
77
+
78
+ ## Data Structure
79
+
80
+ - **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
81
+ - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference)
82
+
83
+ Columns:
84
+ - `"prompt"`: The text sample.
85
+ - `"chosen"`: A version of the text with enhanced descriptiveness.
86
+ - `"rejected"`: A version of the text with less descriptiveness.
87
+
88
+ This structure allows models to learn to prefer the _chosen_ response over the _rejected_ one, thereby aligning with human preferences in descriptive language.
89
+
90
+ ## Generation script
91
+
92
+ The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/lm-human-preferences-descriptiveness.py).
93
+ """)
94
+
95
+ if __name__ == "__main__":
96
+ parser = HfArgumentParser(ScriptArguments)
97
+ script_args = parser.parse_args_into_dataclasses()[0]
98
+
99
+ dataset = load_dataset(
100
+ "json",
101
+ data_files="https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/descriptiveness/offline_5k.json",
102
+ split="train",
103
+ )
104
+
105
+ dataset = dataset.filter(samples_not_all_same, num_proc=script_args.dataset_num_proc)
106
+
107
+ dataset = dataset.map(
108
+ to_prompt_completion,
109
+ num_proc=script_args.dataset_num_proc,
110
+ remove_columns=["query", "sample0", "sample1", "sample2", "sample3", "best"],
111
+ fn_kwargs={"tokenizer": AutoTokenizer.from_pretrained("gpt2")},
112
+ )
113
+
114
+ # train_size taken from https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/launch.py#L79)
115
+ dataset = dataset.train_test_split(train_size=4992)
116
+
117
+ if script_args.push_to_hub:
118
+ dataset.push_to_hub(script_args.repo_id)
119
+ model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
ICL/RL/trl_source/examples/datasets/lm-human-preferences-sentiment.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+
17
+ from datasets import load_dataset
18
+ from huggingface_hub import ModelCard
19
+ from transformers import AutoTokenizer, HfArgumentParser
20
+
21
+
22
+ @dataclass
23
+ class ScriptArguments:
24
+ r"""
25
+ Arguments for the script.
26
+
27
+ Args:
28
+ push_to_hub (`bool`, *optional*, defaults to `False`):
29
+ Whether to push the dataset to the Hugging Face Hub.
30
+ repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-sentiment"`):
31
+ Hugging Face repository ID to push the dataset to.
32
+ dataset_num_proc (`int`, *optional*):
33
+ Number of workers to use for dataset processing.
34
+ """
35
+
36
+ push_to_hub: bool = field(
37
+ default=False,
38
+ metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
39
+ )
40
+ repo_id: str = field(
41
+ default="trl-lib/lm-human-preferences-sentiment",
42
+ metadata={"help": "Hugging Face repository ID to push the dataset to."},
43
+ )
44
+ dataset_num_proc: int | None = field(
45
+ default=None,
46
+ metadata={"help": "Number of workers to use for dataset processing."},
47
+ )
48
+
49
+
50
+ def to_prompt_completion(example, tokenizer):
51
+ prompt = tokenizer.decode(example["query"]).strip()
52
+ best_idx = example["best"]
53
+ chosen = tokenizer.decode(example[f"sample{best_idx}"])
54
+ for rejected_idx in range(4): # take the first rejected sample that is different from the chosen one
55
+ rejected = tokenizer.decode(example[f"sample{rejected_idx}"])
56
+ if chosen != rejected:
57
+ break
58
+ assert chosen != rejected
59
+ return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
60
+
61
+
62
+ model_card = ModelCard("""
63
+ ---
64
+ tags: [trl]
65
+ ---
66
+
67
+ # LM-Human-Preferences-Sentiment Dataset
68
+
69
+ ## Summary
70
+
71
+ The LM-Human-Preferences-Sentiment dataset is a processed subset of [OpenAI's LM-Human-Preferences](https://github.com/openai/lm-human-preferences), focusing specifically on sentiment analysis tasks. It contains pairs of text samples, each labeled as either "chosen" or "rejected," based on human preferences regarding the sentiment conveyed in the text. This dataset enables models to learn human preferences in sentiment expression, enhancing their ability to generate and evaluate text with desired emotional tones.
72
+
73
+ ## Data Structure
74
+
75
+ - **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
76
+ - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference)
77
+
78
+ Columns:
79
+ - `"prompt"`: The text sample.
80
+ - `"chosen"`: A version of the text that conveys the desired sentiment.
81
+ - `"rejected"`: A version of the text that does not convey the desired sentiment.
82
+
83
+ This structure allows models to learn to prefer the _chosen_ response over the _rejected_ one, thereby aligning with human preferences in sentiment expression.
84
+
85
+ ## Generation script
86
+
87
+ The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/lm-human-preferences-sentiment.py).
88
+ """)
89
+
90
+ if __name__ == "__main__":
91
+ parser = HfArgumentParser(ScriptArguments)
92
+ script_args = parser.parse_args_into_dataclasses()[0]
93
+
94
+ dataset = load_dataset(
95
+ "json",
96
+ data_files="https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/sentiment/offline_5k.json",
97
+ split="train",
98
+ )
99
+
100
+ dataset = dataset.map(
101
+ to_prompt_completion,
102
+ num_proc=script_args.dataset_num_proc,
103
+ remove_columns=["query", "sample0", "sample1", "sample2", "sample3", "best"],
104
+ fn_kwargs={"tokenizer": AutoTokenizer.from_pretrained("gpt2")},
105
+ )
106
+
107
+ # train_size taken from https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/launch.py#L70)
108
+ dataset = dataset.train_test_split(train_size=4992)
109
+
110
+ if script_args.push_to_hub:
111
+ dataset.push_to_hub(script_args.repo_id)
112
+ model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
ICL/RL/trl_source/examples/datasets/math_shepherd.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from dataclasses import dataclass, field
17
+ from itertools import chain
18
+
19
+ from datasets import load_dataset
20
+ from huggingface_hub import ModelCard
21
+ from transformers import HfArgumentParser
22
+
23
+
24
+ @dataclass
25
+ class ScriptArguments:
26
+ r"""
27
+ Arguments for the script.
28
+
29
+ Args:
30
+ push_to_hub (`bool`, *optional*, defaults to `False`):
31
+ Whether to push the dataset to the Hugging Face Hub.
32
+ repo_id (`str`, *optional*, defaults to `"trl-lib/math_shepherd"`):
33
+ Hugging Face repository ID to push the dataset to.
34
+ dataset_num_proc (`int`, *optional*):
35
+ Number of workers to use for dataset processing.
36
+ """
37
+
38
+ push_to_hub: bool = field(
39
+ default=False,
40
+ metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
41
+ )
42
+ repo_id: str = field(
43
+ default="trl-lib/math_shepherd",
44
+ metadata={"help": "Hugging Face repository ID to push the dataset to."},
45
+ )
46
+ dataset_num_proc: int | None = field(
47
+ default=None,
48
+ metadata={"help": "Number of workers to use for dataset processing."},
49
+ )
50
+
51
+
52
+ def process_example(example):
53
+ # Replace "ки" with "ⶻ" so that the size of the "input" matches the size of the "label"
54
+ inputs = example["input"].replace("ки", "ⶻ")
55
+
56
+ # Find the indices of the "ⶻ" characters (that should match with the indexes of the "+" or "-" in the label)
57
+ indexes = [m.start() for m in re.finditer("ⶻ", inputs)]
58
+
59
+ # Sanity that all indexes are either "+" or "-"
60
+ assert all(example["label"][idx] in ["+", "-"] for idx in indexes)
61
+
62
+ # Get the labels
63
+ labels = [example["label"][idx] == "+" for idx in indexes]
64
+
65
+ # Split the inputs into steps (caution, the first step is missing here, it is the prompt)
66
+ steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]), strict=True)]
67
+
68
+ # Remove the last step (single ⶻ)
69
+ steps = steps[:-1]
70
+
71
+ # Get the prompt (first part) and completions (rest)
72
+ prompt = steps[0]
73
+ completions = steps[1:]
74
+
75
+ # Remove the heading "ⶻ" and the final whitespace from the completions
76
+ assert all(completion.startswith("ⶻ") for completion in completions)
77
+ completions = [completion[1:].strip() for completion in completions]
78
+
79
+ # At this point, we need to retrieve the first step from the prompt.
80
+ # First, we handle particular cases (annotation error) where we have a first label before the end of the prompt.
81
+ if prompt.startswith(
82
+ (
83
+ "Mr. Rocky",
84
+ "Parker",
85
+ "What is the smallest positive",
86
+ " The Myth",
87
+ "Let $\\mathbf{a}$",
88
+ "Find the arithmetic",
89
+ "Determine an ordered pair",
90
+ "Determine the ordered pair",
91
+ "At the Quill and Scroll stationery",
92
+ "Round to the nearest",
93
+ r"Calculate $\sqrt{10p}",
94
+ r"Simplify $\sqrt{28x}",
95
+ )
96
+ ):
97
+ # Some spotted datasets errors where there is an annotation in the prompt: we remove it
98
+ labels = labels[1:]
99
+
100
+ # Then we handle the general case: we get the first step from the prompt by looking for "Step 1:" or "step 1:" or
101
+ # (less common) "?".
102
+ elif "Step 1:" in prompt:
103
+ prompt, first_step = prompt.split("Step 1:")
104
+ first_step = "Step 1:" + first_step
105
+ completions = [first_step.strip()] + completions
106
+ elif "step 1:" in prompt:
107
+ prompt, first_step = prompt.split("step 1:")
108
+ first_step = "step 1:" + first_step
109
+ completions = [first_step.strip()] + completions
110
+ elif "?" in prompt:
111
+ prompt, first_step = prompt.split("?")
112
+ prompt = prompt + "?"
113
+ completions = [first_step.strip()] + completions
114
+ else:
115
+ raise ValueError(f"Prompt can't be processed: {prompt}")
116
+
117
+ # Strip the prompt
118
+ prompt = prompt.strip()
119
+
120
+ # Sanity check that the length of the completions is the same as the length of the labels
121
+ assert len(completions) == len(labels)
122
+
123
+ return {"prompt": prompt, "completions": completions, "labels": labels}
124
+
125
+
126
+ model_card = ModelCard("""
127
+ ---
128
+ tags: [trl]
129
+ ---
130
+
131
+ # Math-Shepherd Dataset
132
+
133
+ ## Summary
134
+
135
+ The Math-Shepherd dataset is a processed version of [Math-Shepherd dataset](peiyi9979/Math-Shepherd), designed to train models using the [TRL library](https://github.com/huggingface/trl) for stepwise supervision tasks. It provides step-by-step solutions to mathematical problems, enabling models to learn and verify each step of a solution, thereby enhancing their reasoning capabilities.
136
+
137
+ ## Data Structure
138
+
139
+ - **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
140
+ - **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision)
141
+
142
+ Columns:
143
+ - `"prompt"`: The problem statement.
144
+ - `"completions"`: A list of reasoning steps generated to solve the problem.
145
+ - `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step.
146
+
147
+ This structure allows models to learn the correctness of each step in a solution, facilitating improved reasoning and problem-solving abilities.
148
+
149
+ ## Generation script
150
+
151
+ The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/math_shepherd.py).
152
+ """)
153
+
154
+ if __name__ == "__main__":
155
+ parser = HfArgumentParser(ScriptArguments)
156
+ script_args = parser.parse_args_into_dataclasses()[0]
157
+
158
+ dataset = load_dataset("peiyi9979/Math-Shepherd", split="train")
159
+
160
+ dataset = dataset.map(
161
+ process_example,
162
+ remove_columns=["input", "label", "task"],
163
+ num_proc=script_args.dataset_num_proc,
164
+ )
165
+ dataset = dataset.train_test_split(test_size=0.05, seed=42)
166
+
167
+ if script_args.push_to_hub:
168
+ dataset.push_to_hub(script_args.repo_id)
169
+ model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
ICL/RL/trl_source/examples/datasets/prm800k.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+
17
+ from datasets import load_dataset
18
+ from huggingface_hub import ModelCard
19
+ from transformers import HfArgumentParser
20
+
21
+
22
+ @dataclass
23
+ class ScriptArguments:
24
+ r"""
25
+ Arguments for the script.
26
+
27
+ Args:
28
+ push_to_hub (`bool`, *optional*, defaults to `False`):
29
+ Whether to push the dataset to the Hugging Face Hub.
30
+ repo_id (`str`, *optional*, defaults to `"trl-lib/prm800k"`):
31
+ Hugging Face repository ID to push the dataset to.
32
+ dataset_num_proc (`int`, *optional*):
33
+ Number of workers to use for dataset processing.
34
+ """
35
+
36
+ push_to_hub: bool = field(
37
+ default=False,
38
+ metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
39
+ )
40
+ repo_id: str = field(
41
+ default="trl-lib/prm800k",
42
+ metadata={"help": "Hugging Face repository ID to push the dataset to."},
43
+ )
44
+ dataset_num_proc: int | None = field(
45
+ default=None,
46
+ metadata={"help": "Number of workers to use for dataset processing."},
47
+ )
48
+
49
+
50
+ def process_example(example):
51
+ outputs = []
52
+ prompt = example["question"]["problem"]
53
+
54
+ # Iterate through each step
55
+ previous_completions = []
56
+ previous_labels = []
57
+ for step in example["label"]["steps"]:
58
+ if step["completions"] is None and step["human_completion"] is None and step["chosen_completion"] is None:
59
+ # happens sometimes
60
+ break
61
+ # Loop through completions
62
+ for completion_idx, completion in enumerate(step["completions"]):
63
+ # For every completion that are not chosen, we are in a terminal state, so we can add it to the list of outputs.
64
+ if completion_idx != step["chosen_completion"]:
65
+ content = completion["text"]
66
+ completions = previous_completions[:] + [content]
67
+ label = completion["rating"] == 1
68
+ labels = previous_labels[:] + [label]
69
+ outputs.append({"prompt": prompt, "completions": completions, "labels": labels})
70
+
71
+ # Now, expand the previous completions and labels
72
+ if step["chosen_completion"] is not None:
73
+ chosen_completion = step["completions"][step["chosen_completion"]]
74
+ label = chosen_completion["rating"] == 1
75
+ elif step["human_completion"] is not None:
76
+ chosen_completion = step["human_completion"]
77
+ label = True
78
+ else:
79
+ break
80
+ content = chosen_completion["text"]
81
+ previous_completions.append(content)
82
+ previous_labels.append(label)
83
+
84
+ # Last step: we are in a terminal state, so we can add it to the list of outputs
85
+ outputs.append({"prompt": prompt, "completions": previous_completions, "labels": previous_labels})
86
+ return outputs
87
+
88
+
89
+ def process_batch(examples):
90
+ outputs = []
91
+ batch_size = len(examples["label"])
92
+ for idx in range(batch_size):
93
+ example = {k: v[idx] for k, v in examples.items()}
94
+ outputs.extend(process_example(example))
95
+ # list of dict to dict of list
96
+ outputs = {k: [v[k] for v in outputs] for k in outputs[0]}
97
+ return outputs
98
+
99
+
100
+ model_card = ModelCard("""
101
+ ---
102
+ tags: [trl]
103
+ ---
104
+
105
+ # PRM800K Dataset
106
+
107
+ ## Summary
108
+
109
+ The PRM800K dataset is a processed version of [OpenAI's PRM800K](https://github.com/openai/prm800k), designed to train models using the [TRL library](https://github.com/huggingface/trl) for stepwise supervision tasks. It contains 800,000 step-level correctness labels for model-generated solutions to problems from the MATH dataset. This dataset enables models to learn and verify each step of a solution, enhancing their reasoning capabilities.
110
+
111
+ ## Data Structure
112
+
113
+ - **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
114
+ - **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision)
115
+
116
+ Columns:
117
+ - `"prompt"`: The problem statement.
118
+ - `"completions"`: A list of reasoning steps generated to solve the problem.
119
+ - `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step.
120
+
121
+ This structure allows models to learn the correctness of each step in a solution, facilitating improved reasoning and problem-solving abilities.
122
+
123
+ ## Generation script
124
+
125
+ The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/prm800k.py).
126
+ """)
127
+
128
+ if __name__ == "__main__":
129
+ parser = HfArgumentParser(ScriptArguments)
130
+ script_args = parser.parse_args_into_dataclasses()[0]
131
+
132
+ data_files = {
133
+ "train": "https://github.com/openai/prm800k/raw/refs/heads/main/prm800k/data/phase1_train.jsonl",
134
+ "test": "https://github.com/openai/prm800k/raw/refs/heads/main/prm800k/data/phase1_test.jsonl",
135
+ }
136
+ dataset = load_dataset("json", data_files=data_files)
137
+
138
+ dataset = dataset.map(
139
+ process_batch,
140
+ batched=True,
141
+ batch_size=10,
142
+ remove_columns=[
143
+ "labeler",
144
+ "timestamp",
145
+ "generation",
146
+ "is_quality_control_question",
147
+ "is_initial_screening_question",
148
+ "question",
149
+ "label",
150
+ ],
151
+ num_proc=script_args.dataset_num_proc,
152
+ )
153
+
154
+ if script_args.push_to_hub:
155
+ dataset.push_to_hub(script_args.repo_id)
156
+ model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
ICL/RL/trl_source/examples/datasets/rlaif-v.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+
17
+ from datasets import features, load_dataset
18
+ from huggingface_hub import ModelCard
19
+ from transformers import HfArgumentParser
20
+
21
+
22
+ @dataclass
23
+ class ScriptArguments:
24
+ r"""
25
+ Arguments for the script.
26
+
27
+ Args:
28
+ push_to_hub (`bool`, *optional*, defaults to `False`):
29
+ Whether to push the dataset to the Hugging Face Hub.
30
+ repo_id (`str`, *optional*, defaults to `"trl-lib/rlaif-v"`):
31
+ Hugging Face repository ID to push the dataset to.
32
+ dataset_num_proc (`int`, *optional*):
33
+ Number of workers to use for dataset processing.
34
+ """
35
+
36
+ push_to_hub: bool = field(
37
+ default=False,
38
+ metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
39
+ )
40
+ repo_id: str = field(
41
+ default="trl-lib/rlaif-v",
42
+ metadata={"help": "Hugging Face repository ID to push the dataset to."},
43
+ )
44
+ dataset_num_proc: int | None = field(
45
+ default=None,
46
+ metadata={"help": "Number of workers to use for dataset processing."},
47
+ )
48
+
49
+
50
+ def to_conversational(example):
51
+ """
52
+ Convert prompt from "xxx" to [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "xxx"}]}]
53
+ and chosen and rejected from "xxx" to [{"role": "assistant", "content": [{"type": "text", "text": "xxx"}]}].
54
+ Images are wrapped into a list.
55
+ """
56
+ prompt = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": example["question"]}]}]
57
+ chosen = [{"role": "assistant", "content": [{"type": "text", "text": example["chosen"]}]}]
58
+ rejected = [{"role": "assistant", "content": [{"type": "text", "text": example["rejected"]}]}]
59
+ return {"prompt": prompt, "images": [example["image"]], "chosen": chosen, "rejected": rejected}
60
+
61
+
62
+ model_card = ModelCard("""
63
+ ---
64
+ tags: [trl]
65
+ ---
66
+
67
+ # RLAIF-V Dataset
68
+
69
+ ## Summary
70
+
71
+ The RLAIF-V dataset is a processed version of the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset#dataset-card-for-rlaif-v-dataset), specifically curated to train vision-language models using the [TRL library](https://github.com/huggingface/trl) for preference learning tasks. It contains 83,132 high-quality comparison pairs, each comprising an image and two textual descriptions: one preferred and one rejected. This dataset enables models to learn human preferences in visual contexts, enhancing their ability to generate and evaluate image captions.
72
+
73
+ ## Data Structure
74
+
75
+ - **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational)
76
+ - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference)
77
+
78
+ Columns:
79
+ - `"prompt"`: The task related to the image.
80
+ - `"images"`: The image.
81
+ - `"chosen"`: The preferred answer.
82
+ - `"rejected"`: An alternative answer that was not preferred.
83
+
84
+ This structure allows models to learn to prefer the _chosen_ response over the _rejected_ one, thereby aligning with human preferences in visual tasks.
85
+
86
+ ## Generation script
87
+
88
+ The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/rlaif-v.py).
89
+ """)
90
+
91
+ if __name__ == "__main__":
92
+ parser = HfArgumentParser(ScriptArguments)
93
+ script_args = parser.parse_args_into_dataclasses()[0]
94
+
95
+ dataset = load_dataset("openbmb/RLAIF-V-Dataset", split="train")
96
+ dataset = dataset.map(
97
+ to_conversational,
98
+ num_proc=script_args.dataset_num_proc,
99
+ remove_columns=dataset.column_names,
100
+ writer_batch_size=128,
101
+ )
102
+
103
+ # Cast the images to Sequence[Image] to avoid bytes format
104
+ f = dataset.features
105
+ f["images"] = features.Sequence(features.Image(decode=True))
106
+ dataset = dataset.cast(f)
107
+
108
+ dataset = dataset.train_test_split(test_size=0.01, writer_batch_size=128)
109
+
110
+ if script_args.push_to_hub:
111
+ dataset.push_to_hub(script_args.repo_id)
112
+ model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
ICL/RL/trl_source/examples/datasets/tldr.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+
17
+ from datasets import load_dataset
18
+ from huggingface_hub import ModelCard
19
+ from transformers import HfArgumentParser
20
+
21
+
22
+ @dataclass
23
+ class ScriptArguments:
24
+ r"""
25
+ Arguments for the script.
26
+
27
+ Args:
28
+ push_to_hub (`bool`, *optional*, defaults to `False`):
29
+ Whether to push the dataset to the Hugging Face Hub.
30
+ repo_id (`str`, *optional*, defaults to `"trl-lib/tldr"`):
31
+ Hugging Face repository ID to push the dataset to.
32
+ dataset_num_proc (`int`, *optional*):
33
+ Number of workers to use for dataset processing.
34
+ """
35
+
36
+ push_to_hub: bool = field(
37
+ default=False,
38
+ metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
39
+ )
40
+ repo_id: str = field(
41
+ default="trl-lib/tldr",
42
+ metadata={"help": "Hugging Face repository ID to push the dataset to."},
43
+ )
44
+ dataset_num_proc: int | None = field(
45
+ default=None,
46
+ metadata={"help": "Number of workers to use for dataset processing."},
47
+ )
48
+
49
+
50
+ def to_prompt_completion(example):
51
+ tldr_format_str = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"
52
+ prompt = tldr_format_str.format(subreddit=example["subreddit"], title=example["title"], post=example["post"])
53
+ completion = " " + example["summary"] # Add a space to separate the prompt from the completion
54
+ return {"prompt": prompt, "completion": completion}
55
+
56
+
57
+ model_card = ModelCard("""
58
+ ---
59
+ tags: [trl]
60
+ ---
61
+
62
+ # TL;DR Dataset
63
+
64
+ ## Summary
65
+
66
+ The TL;DR dataset is a processed version of Reddit posts, specifically curated to train models using the [TRL library](https://github.com/huggingface/trl) for summarization tasks. It leverages the common practice on Reddit where users append "TL;DR" (Too Long; Didn't Read) summaries to lengthy posts, providing a rich source of paired text data for training summarization models.
67
+
68
+ ## Data Structure
69
+
70
+ - **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
71
+ - **Type**: [Prompt-completion](https://huggingface.co/docs/trl/main/dataset_formats#prompt-completion)
72
+
73
+ Columns:
74
+ - `"prompt"`: The unabridged Reddit post.
75
+ - `"completion"`: The concise "TL;DR" summary appended by the author.
76
+
77
+ This structure enables models to learn the relationship between detailed content and its abbreviated form, enhancing their summarization capabilities.
78
+
79
+ ## Generation script
80
+
81
+ The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/tldr.py).
82
+ """)
83
+
84
+ if __name__ == "__main__":
85
+ parser = HfArgumentParser(ScriptArguments)
86
+ script_args = parser.parse_args_into_dataclasses()[0]
87
+
88
+ # Filtered reddit TL;DR dataset from https://github.com/openai/summarize-from-feedback?tab=readme-ov-file#reddit-tldr-dataset
89
+ data_files = {
90
+ "train": "https://openaipublic.blob.core.windows.net/summarize-from-feedback/datasets/tldr_3_filtered/train.jsonl",
91
+ "validation": "https://openaipublic.blob.core.windows.net/summarize-from-feedback/datasets/tldr_3_filtered/valid.jsonl",
92
+ "test": "https://openaipublic.blob.core.windows.net/summarize-from-feedback/datasets/tldr_3_filtered/test.jsonl",
93
+ }
94
+ dataset = load_dataset("json", data_files=data_files)
95
+
96
+ dataset = dataset.map(
97
+ to_prompt_completion,
98
+ num_proc=script_args.dataset_num_proc,
99
+ remove_columns=["id", "subreddit", "title", "post", "summary"],
100
+ )
101
+
102
+ if script_args.push_to_hub:
103
+ dataset.push_to_hub(script_args.repo_id)
104
+ model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
ICL/RL/trl_source/examples/datasets/tldr_preference.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+
17
+ from datasets import load_dataset
18
+ from huggingface_hub import ModelCard
19
+ from transformers import HfArgumentParser
20
+
21
+
22
+ @dataclass
23
+ class ScriptArguments:
24
+ r"""
25
+ Arguments for the script.
26
+
27
+ Args:
28
+ push_to_hub (`bool`, *optional*, defaults to `False`):
29
+ Whether to push the dataset to the Hugging Face Hub.
30
+ repo_id (`str`, *optional*, defaults to `"trl-lib/tldr-preference"`):
31
+ Hugging Face repository ID to push the dataset to.
32
+ dataset_num_proc (`int`, *optional*):
33
+ Number of workers to use for dataset processing.
34
+ """
35
+
36
+ push_to_hub: bool = field(
37
+ default=False,
38
+ metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
39
+ )
40
+ repo_id: str = field(
41
+ default="trl-lib/tldr-preference",
42
+ metadata={"help": "Hugging Face repository ID to push the dataset to."},
43
+ )
44
+ dataset_num_proc: int | None = field(
45
+ default=None,
46
+ metadata={"help": "Number of workers to use for dataset processing."},
47
+ )
48
+
49
+
50
+ def to_preference(example):
51
+ info = example["info"]
52
+ if example["batch"] in ["batch0_cnndm", "cnndm0", "cnndm2"]: # CNN Daily Mail batches
53
+ article = info["article"].replace("\n\n", "\n")
54
+ prompt = f"TITLE: {info['title']}\n\n{article}\n\nTL;DR:"
55
+ elif example["batch"] in [f"batch{i}" for i in range(3, 23)] + ["edit_b2_eval_test"]: # Reddit batches
56
+ post = info["post"].replace("\n\n", "\n")
57
+ prompt = f"SUBREDDIT: r/{info['subreddit']}\n\nTITLE: {info['title']}\n\nPOST: {post}\n\nTL;DR:"
58
+ else:
59
+ raise ValueError(f"Unknown batch: {example['batch']}")
60
+
61
+ chosen_idx = example["choice"]
62
+ rejected_idx = 1 - chosen_idx
63
+ chosen = example["summaries"][chosen_idx]["text"]
64
+ rejected = example["summaries"][rejected_idx]["text"]
65
+ return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
66
+
67
+
68
+ model_card = ModelCard("""
69
+ ---
70
+ tags: [trl]
71
+ ---
72
+
73
+ # TL;DR Dataset for Preference Learning
74
+
75
+ ## Summary
76
+
77
+ The TL;DR dataset is a processed version of Reddit posts, specifically curated to train models using the [TRL library](https://github.com/huggingface/trl) for preference learning and Reinforcement Learning from Human Feedback (RLHF) tasks. It leverages the common practice on Reddit where users append "TL;DR" (Too Long; Didn't Read) summaries to lengthy posts, providing a rich source of paired text data for training models to understand and generate concise summaries.
78
+
79
+ ## Data Structure
80
+
81
+ - **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
82
+ - **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference)
83
+
84
+ Columns:
85
+ - `"prompt"`: The unabridged Reddit post.
86
+ - `"chosen"`: The concise "TL;DR" summary appended by the author.
87
+ - `"rejected"`: An alternative summary or response that was not selected.
88
+
89
+ This structure enables models to learn the relationship between detailed content and its abbreviated form, enhancing their summarization capabilities.
90
+
91
+ ## Generation script
92
+
93
+ The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/tldr_preference.py).
94
+ """)
95
+
96
+ if __name__ == "__main__":
97
+ parser = HfArgumentParser(ScriptArguments)
98
+ script_args = parser.parse_args_into_dataclasses()[0]
99
+
100
+ dataset = load_dataset("openai/summarize_from_feedback", "comparisons")
101
+
102
+ dataset = dataset.map(
103
+ to_preference,
104
+ num_proc=script_args.dataset_num_proc,
105
+ remove_columns=["info", "summaries", "choice", "worker", "batch", "split", "extra"],
106
+ )
107
+
108
+ if script_args.push_to_hub:
109
+ dataset.push_to_hub(script_args.repo_id)
110
+ model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
ICL/RL/trl_source/examples/datasets/ultrafeedback-prompt.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+
17
+ from datasets import load_dataset
18
+ from huggingface_hub import ModelCard
19
+ from transformers import HfArgumentParser
20
+
21
+
22
+ @dataclass
23
+ class ScriptArguments:
24
+ r"""
25
+ Arguments for the script.
26
+
27
+ Args:
28
+ push_to_hub (`bool`, *optional*, defaults to `False`):
29
+ Whether to push the dataset to the Hugging Face Hub.
30
+ repo_id (`str`, *optional*, defaults to `"trl-lib/ultrafeedback-prompt"`):
31
+ Hugging Face repository ID to push the dataset to.
32
+ dataset_num_proc (`int`, *optional*):
33
+ Number of workers to use for dataset processing.
34
+ """
35
+
36
+ push_to_hub: bool = field(
37
+ default=False,
38
+ metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
39
+ )
40
+ repo_id: str = field(
41
+ default="trl-lib/ultrafeedback-prompt",
42
+ metadata={"help": "Hugging Face repository ID to push the dataset to."},
43
+ )
44
+ dataset_num_proc: int | None = field(
45
+ default=None,
46
+ metadata={"help": "Number of workers to use for dataset processing."},
47
+ )
48
+
49
+
50
+ def to_unpaired_preference(example):
51
+ prompt = [{"role": "user", "content": example["instruction"]}]
52
+ return {"prompt": prompt}
53
+
54
+
55
+ def drop_long_prompt(example):
56
+ if len(example["prompt"][0]["content"]) > 512:
57
+ return False
58
+ else:
59
+ return True
60
+
61
+
62
+ model_card = ModelCard("""
63
+ ---
64
+ tags: [trl]
65
+ ---
66
+
67
+ # UltraFeedback - Prompts Dataset
68
+
69
+ ## Summary
70
+
71
+ The UltraFeedback - Prompts dataset is a processed version of the [UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset for model evaluation on specific aspects like helpfulness, honesty, and instruction-following.
72
+
73
+ ## Data Structure
74
+
75
+ - **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational)
76
+ - **Type**: [Prompt-only](https://huggingface.co/docs/trl/main/dataset_formats#prompt-only)
77
+
78
+ Column:
79
+ - `"prompt"`: The input question or instruction provided to the model.
80
+
81
+ ## Generation script
82
+
83
+ The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/ultrafeedback-prompt.py).
84
+ """)
85
+
86
+ if __name__ == "__main__":
87
+ parser = HfArgumentParser(ScriptArguments)
88
+ script_args = parser.parse_args_into_dataclasses()[0]
89
+
90
+ dataset = load_dataset("openbmb/UltraFeedback", split="train")
91
+
92
+ dataset = dataset.map(
93
+ to_unpaired_preference,
94
+ remove_columns=["source", "instruction", "models", "completions", "correct_answers", "incorrect_answers"],
95
+ num_proc=script_args.dataset_num_proc,
96
+ )
97
+ dataset = dataset.filter(drop_long_prompt)
98
+ dataset = dataset.train_test_split(test_size=0.05, seed=42)
99
+
100
+ if script_args.push_to_hub:
101
+ dataset.push_to_hub(script_args.repo_id)
102
+ model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
ICL/RL/trl_source/examples/datasets/ultrafeedback.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+
17
+ from datasets import load_dataset
18
+ from huggingface_hub import ModelCard
19
+ from transformers import HfArgumentParser
20
+
21
+
22
+ @dataclass
23
+ class ScriptArguments:
24
+ r"""
25
+ Arguments for the script.
26
+
27
+ Args:
28
+ model_name (`str`, *optional*, defaults to `"gpt-3.5-turbo"`):
29
+ Language model to target. Possible values are:
30
+ aspect (`str`, *optional*, defaults to `"helpfulness"`):
31
+ Aspect to target.
32
+ push_to_hub (`bool`, *optional*, defaults to `False`):
33
+ Whether to push the dataset to the Hugging Face Hub.
34
+ repo_id (`str`, *optional*, defaults to `"trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness"`):
35
+ Hugging Face repository ID to push the dataset to.
36
+ dataset_num_proc (`int`, *optional*):
37
+ Number of workers to use for dataset processing.
38
+ """
39
+
40
+ model_name: str = field(
41
+ default="gpt-3.5-turbo",
42
+ metadata={
43
+ "help": "Language model to target.",
44
+ "choices": [
45
+ "alpaca-7b",
46
+ "bard",
47
+ "falcon-40b-instruct",
48
+ "gpt-3.5-turbo",
49
+ "gpt-4",
50
+ "llama-2-13b-chat",
51
+ "llama-2-70b-chat",
52
+ "llama-2-7b-chat",
53
+ "mpt-30b-chat",
54
+ "pythia-12b",
55
+ "starchat",
56
+ "ultralm-13b",
57
+ "ultralm-65b",
58
+ "vicuna-33b",
59
+ "wizardlm-13b",
60
+ "wizardlm-70b",
61
+ "wizardlm-7b",
62
+ ],
63
+ },
64
+ )
65
+ aspect: str = field(
66
+ default="helpfulness",
67
+ metadata={
68
+ "help": "Aspect to target. Possible values are: 'helpfulness' (default), 'honesty', "
69
+ "'instruction-following', 'truthfulness'.",
70
+ "choices": ["helpfulness", "honesty", "instruction-following", "truthfulness"],
71
+ },
72
+ )
73
+ push_to_hub: bool = field(
74
+ default=False,
75
+ metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
76
+ )
77
+ repo_id: str = field(
78
+ default="trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness",
79
+ metadata={"help": "Hugging Face repository ID to push the dataset to."},
80
+ )
81
+ dataset_num_proc: int | None = field(
82
+ default=None,
83
+ metadata={"help": "Number of workers to use for dataset processing."},
84
+ )
85
+
86
+
87
+ def to_unpaired_preference(example, model_name, aspect):
88
+ prompt = [{"role": "user", "content": example["instruction"]}]
89
+ model_index = example["models"].index(model_name)
90
+ response_content = example["completions"][model_index]["response"]
91
+ completion = [{"role": "assistant", "content": response_content}]
92
+ score = int(example["completions"][model_index]["annotations"][aspect]["Rating"])
93
+ label = score >= 5
94
+ return {"prompt": prompt, "completion": completion, "label": label}
95
+
96
+
97
+ model_card = ModelCard("""
98
+ ---
99
+ tags: [trl]
100
+ ---
101
+
102
+ # UltraFeedback GPT-3.5-Turbo Helpfulness Dataset
103
+
104
+ ## Summary
105
+
106
+ The UltraFeedback GPT-3.5-Turbo Helpfulness dataset contains processed user-assistant interactions filtered for helpfulness, derived from the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. It is designed for fine-tuning and evaluating models in alignment tasks.
107
+
108
+ ## Data Structure
109
+
110
+ - **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational)
111
+ - **Type**: [Unpaired preference](https://huggingface.co/docs/trl/main/dataset_formats#unpaired-preference)
112
+
113
+ Column:
114
+ - `"prompt"`: The input question or instruction provided to the model.
115
+ - `"completion"`: The model's response to the prompt.
116
+ - `"label"`: A binary value indicating whether the response is sufficiently helpful.
117
+
118
+ ## Generation script
119
+
120
+ The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/ultrafeedback.py).
121
+ """)
122
+
123
+ if __name__ == "__main__":
124
+ parser = HfArgumentParser(ScriptArguments)
125
+ script_args = parser.parse_args_into_dataclasses()[0]
126
+
127
+ dataset = load_dataset("openbmb/UltraFeedback", split="train")
128
+
129
+ dataset = dataset.filter(
130
+ lambda example: script_args.model_name in example["models"],
131
+ batched=False,
132
+ num_proc=script_args.dataset_num_proc,
133
+ )
134
+ dataset = dataset.map(
135
+ to_unpaired_preference,
136
+ remove_columns=["source", "instruction", "models", "completions", "correct_answers", "incorrect_answers"],
137
+ fn_kwargs={"model_name": script_args.model_name, "aspect": script_args.aspect},
138
+ num_proc=script_args.dataset_num_proc,
139
+ )
140
+ dataset = dataset.train_test_split(test_size=0.05, seed=42)
141
+
142
+ if script_args.push_to_hub:
143
+ dataset.push_to_hub(script_args.repo_id)
144
+ model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
ICL/RL/trl_source/examples/notebooks/README.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Notebooks
2
+
3
+ This directory contains a collection of Jupyter notebooks that demonstrate how to use the TRL library in different applications.
4
+
5
+ | Notebook | Description | Open in Colab |
6
+ | --- | --- | --- |
7
+ | [`grpo_trl_lora_qlora.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_trl_lora_qlora.ipynb) | GRPO using QLoRA on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_trl_lora_qlora.ipynb) |
8
+ | [`grpo_functiongemma_browsergym_openenv.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb) | GRPO on FunctionGemma in the BrowserGym environment | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb) |
9
+ | [`grpo_agent.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_agent.ipynb) | GRPO for agent training | Not available due to OOM with Colab GPUs |
10
+ | [`grpo_rnj_1_instruct.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_rnj_1_instruct.ipynb) | GRPO rnj-1-instruct with QLoRA using TRL on Colab to add reasoning capabilities | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_rnj_1_instruct.ipynb) |
11
+ | [`sft_ministral3_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_ministral3_vl.ipynb) | Supervised Fine-Tuning (SFT) Ministral 3 with QLoRA using TRL on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_ministral3_vl.ipynb) |
12
+ | [`grpo_ministral3_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_ministral3_vl.ipynb) | GRPO Ministral 3 with QLoRA using TRL on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_ministral3_vl.ipynb) |
13
+ | [`openenv_sudoku_grpo.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/openenv_sudoku_grpo.ipynb) | GRPO to play Sudoku on an OpenEnv environment | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_sudoku_grpo.ipynb) |
14
+ | [`openenv_wordle_grpo.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/openenv_wordle_grpo.ipynb) | GRPO to play Worldle on an OpenEnv environment | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_wordle_grpo.ipynb) |
15
+ | [`sft_trl_lora_qlora.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_trl_lora_qlora.ipynb) | Supervised Fine-Tuning (SFT) using QLoRA on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb) |
16
+ | [`sft_qwen_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_qwen_vl.ipynb) | Supervised Fine-Tuning (SFT) Qwen3-VL with QLoRA using TRL on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_qwen_vl.ipynb) |
17
+ | [`grpo_qwen3_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_qwen3_vl.ipynb) | GRPO Qwen3-VL with QLoRA using TRL on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_qwen3_vl.ipynb) |
ICL/RL/trl_source/examples/notebooks/grpo_agent.ipynb ADDED
@@ -0,0 +1,706 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "63ceecbc-87ad-4ad3-a317-f49267ffc93b",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Agent Training with GRPO using TRL\n",
9
+ "\n",
10
+ "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n",
11
+ "\n",
12
+ "\n",
13
+ "With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can train a language model to act as an **agent**. One that learns to reason, interact with external tools, and improve through reinforcement.\n",
14
+ "\n",
15
+ "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n",
16
+ "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n",
17
+ "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)\n",
18
+ "- [OpenEnv](https://github.com/meta-pytorch/OpenEnv)\n",
19
+ "\n",
20
+ "\n",
21
+ "TRL supports training agents that can use external tools as part of their decision process. \n",
22
+ "In this notebook, the agent has access to the **BioGRID database**, which it can query using **read-only SQL commands** to retrieve biological interaction data. The model learns when and how to use tools based on rewards.\n",
23
+ "\n",
24
+ "We'll fine-tune a model using GRPO (Group Relative Policy Optimization) via TRL. The agent will:\n",
25
+ "\n",
26
+ "1. Generate tool call to query the database if needed.\n",
27
+ "2. Receive the tool response and add it it to the context.\n",
28
+ "3. Learn to improve its tool usage and general capabilities over time through reward signals.\n",
29
+ "\n",
30
+ "## Install dependencies\n",
31
+ "\n",
32
+ "We'll start by installing **TRL**, which automatically includes the main dependencies like **Transformers**. \n",
33
+ "We'll also install **trackio** (for logging and monitoring training runs), **vLLM** (for efficient generation), and **jmespath** (needed for the tools capabilities)."
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "id": "b4812fbf-3f61-481e-9a64-95277eada9c9",
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "!pip install -Uq \"trl[vllm]\" git+https://github.com/huggingface/transformers.git trackio jmespath "
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "id": "ede8e566-a1b5-460f-9fe8-a6010bc56148",
49
+ "metadata": {},
50
+ "source": [
51
+ "### Log in to Hugging Face\n",
52
+ "\n",
53
+ "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)."
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "id": "21756ac0-78b2-495d-8137-28dfa9faae6a",
60
+ "metadata": {},
61
+ "outputs": [],
62
+ "source": [
63
+ "from huggingface_hub import notebook_login\n",
64
+ "\n",
65
+ "notebook_login()"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "id": "KVGklspLYlmz",
71
+ "metadata": {},
72
+ "source": [
73
+ "## Create the database for the tool\n",
74
+ "\n",
75
+ "For this example, we will use the [BioGRID database](https://thebiogrid.org/), a curated resource containing **protein, genetic, and chemical interaction data**. We've already compiled and uploaded it to the Hub at [qgallouedec/biogrid](https://huggingface.co/datasets/qgallouedec/biogrid). The dataset is loaded and converted into an sqlite database.\n",
76
+ "\n",
77
+ "> 💡 We remove spaces in the column names to easen the model work. In real-world deployments, you may keep your original column names and rely on the agent to reason about them. Here, we simplify the schema to make training smoother."
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "id": "rRzPMhfXBLkF",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "import sqlite3\n",
88
+ "from datasets import load_dataset\n",
89
+ "\n",
90
+ "# Load dataset\n",
91
+ "biogrid_dataset = load_dataset(\"qgallouedec/biogrid\", split=\"train\")\n",
92
+ "df = biogrid_dataset.to_pandas()\n",
93
+ "\n",
94
+ "# Normalize column names: remove spaces, replace with underscores\n",
95
+ "df.columns = [c.replace(\" \", \"_\") for c in df.columns]\n",
96
+ "\n",
97
+ "# Save to SQLite\n",
98
+ "conn = sqlite3.connect(\"biogrid.db\")\n",
99
+ "try:\n",
100
+ " df.to_sql(\"interactions\", conn, if_exists=\"replace\", index=False)\n",
101
+ " print(f\"biogrid.db created. Rows stored: {len(df)}\")\n",
102
+ "finally:\n",
103
+ " conn.close()"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "markdown",
108
+ "id": "pSSGvLbmZyC2",
109
+ "metadata": {},
110
+ "source": [
111
+ "## Load the QA dataset\n",
112
+ "\n",
113
+ "The training objective is to fine-tune a model to answer gene-related questions. The model should learn to use the database query tool to retrieve factual information when needed.\n",
114
+ "\n",
115
+ "We'll define a formatting function for each sample, adding instructions about the database and how to call it. The model must answer with **yes** or **no**. Let's implement the `format_example` function.\n",
116
+ "\n"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": null,
122
+ "id": "asrv7LbaD71C",
123
+ "metadata": {},
124
+ "outputs": [],
125
+ "source": [
126
+ "import textwrap\n",
127
+ "\n",
128
+ "def format_example(example):\n",
129
+ " question = example[\"question\"]\n",
130
+ " preamble = textwrap.dedent(\"\"\"\\\n",
131
+ " You have access to the BioGRID SQLite database.\n",
132
+ " Use SQL queries to retrieve only the information needed to answer the question.\n",
133
+ "\n",
134
+ " Genes may appear in the database in columns `Alt_IDs_Interactor_A` `Alt_IDs_Interactor_B`, `Aliases_Interactor_A` and `Aliases_Interactor_B`,\n",
135
+ " and each entry can contain multiple gene names or synonyms separated by '|', for example:\n",
136
+ " 'entrez gene/locuslink:JNKK(gene name synonym)|entrez gene/locuslink:MAPKK4(gene name synonym)|...'\n",
137
+ " So a gene like 'JNKK' or 'MAPKK4' may appear inside one of these strings.\n",
138
+ "\n",
139
+ " If the database schema is unclear or you are unsure about column names:\n",
140
+ " - First inspect the schema with `PRAGMA table_info(interactions);`\n",
141
+ " - Or preview a few rows with `SELECT * FROM interactions LIMIT 1;`\n",
142
+ "\n",
143
+ " Otherwise, directly query the required data.\n",
144
+ "\n",
145
+ " Final answer must be enclosed in stars, e.g. *Yes* or *No*.\n",
146
+ " Facts:\n",
147
+ " - The NCBI Taxonomy identifier for humans is taxid:9606.\n",
148
+ " \"\"\")\n",
149
+ " content = f\"{preamble}\\nQuestion: {question}\"\n",
150
+ " prompt = [{\"role\": \"user\", \"content\": content}]\n",
151
+ " return {\"prompt\": prompt}"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "markdown",
156
+ "id": "UMnHXYZla_EO",
157
+ "metadata": {},
158
+ "source": [
159
+ "Now, let's load the database and call the previous function. \n",
160
+ "For simplicity, we will only use questions that start with **“Does the gene…”**. \n",
161
+ "In a real use case, the full dataset can be used.\n",
162
+ "\n",
163
+ "The QA dataset is available on the [Hub](https://huggingface.co/datasets/qgallouedec/biogrid_qa)."
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "id": "jEs12KqwDnVl",
170
+ "metadata": {},
171
+ "outputs": [],
172
+ "source": [
173
+ "dataset = load_dataset(\"qgallouedec/biogrid_qa\", split=\"train\")\n",
174
+ "dataset = dataset.filter(\n",
175
+ " lambda example: example[\"question\"].startswith(\"Does the gene \")\n",
176
+ ") # keep only simple questions for example\n",
177
+ "dataset = dataset.map(format_example, remove_columns=[\"question\"])\n",
178
+ "\n",
179
+ "train_dataset = dataset\n",
180
+ "eval_dataset = None # No eval by default, can be added if needed"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "markdown",
185
+ "id": "m4GRjbHycM5L",
186
+ "metadata": {},
187
+ "source": [
188
+ "## Create tool for the agent\n",
189
+ "\n",
190
+ "The `query_biogrid` function is the tool the model will use to query the database and retrieve factual information. \n",
191
+ "Each tool must be a standard Python function with **type-hinted arguments and return types**, and a **Google-style docstring** describing its purpose, parameters, and return value."
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": null,
197
+ "id": "nLMH7hahGTyO",
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "from contextlib import contextmanager\n",
202
+ "import signal\n",
203
+ "\n",
204
+ "@contextmanager\n",
205
+ "def timeout(seconds):\n",
206
+ " \"\"\"Context manager that raises TimeoutError if execution exceeds time limit.\"\"\"\n",
207
+ "\n",
208
+ " def timeout_handler(signum, frame):\n",
209
+ " raise TimeoutError(f\"Operation timed out after {seconds} seconds\")\n",
210
+ "\n",
211
+ " signal.signal(signal.SIGALRM, timeout_handler)\n",
212
+ " signal.alarm(seconds)\n",
213
+ " try:\n",
214
+ " yield\n",
215
+ " finally:\n",
216
+ " signal.alarm(0)\n",
217
+ "\n",
218
+ "def query_biogrid(sql_command: str) -> list[tuple]:\n",
219
+ " \"\"\"\n",
220
+ " Execute a read-only SQL command on the BioGRID database.\n",
221
+ "\n",
222
+ " BioGRID is a curated biological database that compiles protein, genetic, and chemical interactions from multiple organisms. It provides researchers with experimentally verified interaction data to support studies in systems biology and functional genomics.\n",
223
+ "\n",
224
+ " Args:\n",
225
+ " sql_command: The SQL command to execute.\n",
226
+ "\n",
227
+ " Returns:\n",
228
+ " A list of tuples containing the query results.\n",
229
+ " \"\"\"\n",
230
+ " with timeout(5):\n",
231
+ " conn = sqlite3.connect(\"file:biogrid.db?mode=ro\", uri=True)\n",
232
+ " cursor = conn.cursor()\n",
233
+ " try:\n",
234
+ " cursor.execute(sql_command)\n",
235
+ " results = cursor.fetchall()\n",
236
+ " finally:\n",
237
+ " conn.close()\n",
238
+ " return results"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "markdown",
243
+ "id": "GiHtooTwci3B",
244
+ "metadata": {},
245
+ "source": [
246
+ "## Define reward functions\n",
247
+ "\n",
248
+ "To guide the agent during training, we define a few simple reward functions:\n",
249
+ "\n",
250
+ "- **`query_reward`**: evaluates the model’s query strategy — penalizes more than two queries, penalizes generic database scans, and rewards use of `WHERE` and evidence supporting the final answer.\n",
251
+ "- **`correctness_reward`**: rewards Yes/No predictions that match the expected answer.\n",
252
+ "- **`structure_reward`**: rewards a proper assistant structure (tool call → response → optional explanation).\n",
253
+ "\n",
254
+ "Each function returns a list of floats used by the **GRPOTrainer** during optimization. \n",
255
+ "Combined, they encourage effective tool use and factual answers."
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": null,
261
+ "id": "sXyqC6cJGe3L",
262
+ "metadata": {},
263
+ "outputs": [],
264
+ "source": [
265
+ "import re\n",
266
+ "\n",
267
+ "def query_reward(completions, answer, **kwargs):\n",
268
+ " \"\"\"\n",
269
+ " Reward query strategy:\n",
270
+ " - Penalize more than 2 queries\n",
271
+ " - Penalize generic queries (LIMIT 1 / PRAGMA)\n",
272
+ " - Reward usage of WHERE\n",
273
+ " - Reward evidence supporting the final answer\n",
274
+ " \"\"\"\n",
275
+ " rewards = []\n",
276
+ "\n",
277
+ " for completion, ans in zip(completions, answer, strict=False):\n",
278
+ " reward = 0.0\n",
279
+ " sql_queries = []\n",
280
+ " tool_results = []\n",
281
+ "\n",
282
+ " # collect all SQL queries and tool results\n",
283
+ " for turn in completion:\n",
284
+ " if turn.get(\"tool_calls\"):\n",
285
+ " for call in turn[\"tool_calls\"]:\n",
286
+ " sql = call[\"function\"][\"arguments\"].get(\"sql_command\", \"\").lower()\n",
287
+ " sql_queries.append(sql)\n",
288
+ " if turn.get(\"role\") == \"tool\" and turn.get(\"content\"):\n",
289
+ " tool_results.append(turn[\"content\"])\n",
290
+ "\n",
291
+ " # --- penalize too many queries ---\n",
292
+ " if len(sql_queries) > 3:\n",
293
+ " reward -= 1.5\n",
294
+ "\n",
295
+ " # --- check query quality ---\n",
296
+ " where_count = 0\n",
297
+ " for q in sql_queries:\n",
298
+ " if \"limit 1\" in q:\n",
299
+ " reward -= 1.0\n",
300
+ " if \" where \" not in q:\n",
301
+ " reward -= 0.5\n",
302
+ " else:\n",
303
+ " where_count += 1\n",
304
+ " reward += min(where_count, 3) * 0.4 # small bonus for WHERE usage\n",
305
+ "\n",
306
+ " # --- evidence check: do queries support the answer? ---\n",
307
+ " combined_results = []\n",
308
+ " error_detected = False\n",
309
+ "\n",
310
+ " for res in tool_results:\n",
311
+ " if isinstance(res, dict) and \"error\" in res:\n",
312
+ " error_detected = True\n",
313
+ " elif isinstance(res, list):\n",
314
+ " combined_results.extend(res)\n",
315
+ "\n",
316
+ " # if error detected, penalize heavily\n",
317
+ " if error_detected:\n",
318
+ " reward -= 2.0\n",
319
+ " elif len(sql_queries) == 0:\n",
320
+ " reward -= 1.5\n",
321
+ " else:\n",
322
+ " has_hits = len(combined_results) > 0\n",
323
+ " correct_answer = ans.lower()\n",
324
+ " if (has_hits and correct_answer == \"yes\") or (not has_hits and correct_answer == \"no\"):\n",
325
+ " reward += 2.0\n",
326
+ " else:\n",
327
+ " reward -= 1.5\n",
328
+ "\n",
329
+ " rewards.append(reward)\n",
330
+ "\n",
331
+ " return rewards\n",
332
+ "\n",
333
+ "\n",
334
+ "def correctness_reward(completions, answer, **kwargs):\n",
335
+ " \"\"\"\n",
336
+ " Reward Yes/No correctness.\n",
337
+ " Model must provide final answer enclosed in stars — *yes* or *no*.\n",
338
+ " Does not reward informal yes/no buried in text.\n",
339
+ " \"\"\"\n",
340
+ " rewards = []\n",
341
+ " for completion, ans in zip(completions, answer, strict=False):\n",
342
+ " raw = completion[-1][\"content\"].lower()\n",
343
+ "\n",
344
+ " # detect form *yes* or *no*\n",
345
+ " match = re.search(r\"\\*(yes|no)\\*\", raw)\n",
346
+ " guess = match.group(1) if match else None\n",
347
+ "\n",
348
+ " reward = 0.0\n",
349
+ "\n",
350
+ " if guess is None:\n",
351
+ " reward -= 0.5 # invalid format\n",
352
+ " elif guess == ans.lower():\n",
353
+ " reward += 0.6 # correct under required format\n",
354
+ " else:\n",
355
+ " reward -= 1.0 # wrong answer\n",
356
+ "\n",
357
+ " rewards.append(reward)\n",
358
+ "\n",
359
+ " return rewards\n",
360
+ "\n",
361
+ "\n",
362
+ "def structure_reward(completions, **kwargs):\n",
363
+ " \"\"\"\n",
364
+ " Reward proper assistant structure.\n",
365
+ " Encourages a logical sequence: tool call + response + optional extra content.\n",
366
+ " \"\"\"\n",
367
+ " rewards = []\n",
368
+ "\n",
369
+ " for completion in completions:\n",
370
+ " has_call = False\n",
371
+ " has_response = False\n",
372
+ " has_other = False\n",
373
+ "\n",
374
+ " for turn in completion:\n",
375
+ " role = turn.get(\"role\")\n",
376
+ " if role == \"assistant\" and turn.get(\"tool_calls\"):\n",
377
+ " has_call = True\n",
378
+ " elif role == \"tool\":\n",
379
+ " has_response = True\n",
380
+ " else:\n",
381
+ " content = turn.get(\"content\")\n",
382
+ " if content and content.strip() not in [\"\", \"<think>\"]:\n",
383
+ " has_other = True\n",
384
+ "\n",
385
+ " # Reward sequences\n",
386
+ " if has_call and has_response:\n",
387
+ " if has_other:\n",
388
+ " reward = 0.1\n",
389
+ " else:\n",
390
+ " reward = 0.05 # still positive even without extra text\n",
391
+ " elif has_call and not has_response:\n",
392
+ " reward = -0.15\n",
393
+ " else:\n",
394
+ " reward = 0.0 # neutral if no call\n",
395
+ "\n",
396
+ " rewards.append(reward)\n",
397
+ "\n",
398
+ " return rewards\n"
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "markdown",
403
+ "id": "zcgkrKtTb4T9",
404
+ "metadata": {},
405
+ "source": [
406
+ "## Set GRPO Config\n",
407
+ "\n",
408
+ "Next, we define the **GRPOConfig**, which controls the main training parameters. \n",
409
+ "This configuration specifies how the model interacts with **vLLM**, manages memory, and logs results."
410
+ ]
411
+ },
412
+ {
413
+ "cell_type": "code",
414
+ "execution_count": null,
415
+ "id": "t4ifJsNLElIN",
416
+ "metadata": {},
417
+ "outputs": [],
418
+ "source": [
419
+ "from trl import GRPOConfig\n",
420
+ "\n",
421
+ "output_dir = \"grpo_biogrid_qwen_3g-1.7b\"\n",
422
+ "\n",
423
+ "grpo_config = GRPOConfig(\n",
424
+ " # Training schedule / optimization\n",
425
+ " max_steps=400, # Max number of training steps\n",
426
+ " chat_template_kwargs = {\"enable_thinking\": False}, # Disable thinking to reduce token generation\n",
427
+ "\n",
428
+ " # GRPO configuration\n",
429
+ " max_completion_length = 1024, # Maximum tokens generated per model response\n",
430
+ "\n",
431
+ " # vLLM configuration\n",
432
+ " use_vllm = True, # Enable vLLM for faster inference during rollouts\n",
433
+ " vllm_mode = \"colocate\", # Run vLLM in colocate mode (same process as training)\n",
434
+ " vllm_enable_sleep_mode=False,\n",
435
+ "\n",
436
+ " # Logging / reporting\n",
437
+ " output_dir = output_dir, # Directory for checkpoints and logs\n",
438
+ " report_to=\"trackio\", # Experiment tracking tool (integrates with HF Spaces)\n",
439
+ " trackio_space_id = output_dir, # HF Space where experiment tracking will be saved\n",
440
+ " save_steps = 10, # Interval for saving checkpoints\n",
441
+ " log_completions = True,\n",
442
+ "\n",
443
+ " # Memory optimization\n",
444
+ " gradient_checkpointing = True, # Enable activation recomputation to save memory\n",
445
+ "\n",
446
+ " # Hub integration\n",
447
+ " push_to_hub = True, # Set True to automatically push model to Hugging Face Hub\n",
448
+ ")"
449
+ ]
450
+ },
451
+ {
452
+ "cell_type": "markdown",
453
+ "id": "34I-Q2MJuf42",
454
+ "metadata": {},
455
+ "source": [
456
+ "## Create `GRPOTrainer` and Start Training\n",
457
+ "\n",
458
+ "Next, we initialize the **`GRPOTrainer`**, which handles the full reinforcement learning loop.\n",
459
+ "\n",
460
+ "It receives the model name, reward functions, tool(s), and dataset defined earlier. \n",
461
+ "\n",
462
+ "Finally, we call `trainer.train()` to begin fine-tuning, allowing the model to learn how to query the database effectively through iterative feedback."
463
+ ]
464
+ },
465
+ {
466
+ "cell_type": "code",
467
+ "execution_count": null,
468
+ "id": "IysntAUOFvRn",
469
+ "metadata": {},
470
+ "outputs": [],
471
+ "source": [
472
+ "from trl import GRPOTrainer\n",
473
+ "\n",
474
+ "model_name=\"Qwen/Qwen3-1.7B\"\n",
475
+ "\n",
476
+ "trainer = GRPOTrainer(\n",
477
+ " model=model_name,\n",
478
+ " train_dataset=train_dataset,\n",
479
+ " eval_dataset=eval_dataset,\n",
480
+ " tools=[query_biogrid],\n",
481
+ " reward_funcs=[correctness_reward, structure_reward, query_reward],\n",
482
+ " args=grpo_config,\n",
483
+ ")"
484
+ ]
485
+ },
486
+ {
487
+ "cell_type": "markdown",
488
+ "id": "r_qJ5UwLuzCG",
489
+ "metadata": {},
490
+ "source": [
491
+ "Show memory stats before training"
492
+ ]
493
+ },
494
+ {
495
+ "cell_type": "code",
496
+ "execution_count": null,
497
+ "id": "DusT8JUaGmA6",
498
+ "metadata": {},
499
+ "outputs": [],
500
+ "source": [
501
+ "import torch\n",
502
+ "gpu_stats = torch.cuda.get_device_properties(0)\n",
503
+ "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
504
+ "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
505
+ "\n",
506
+ "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
507
+ "print(f\"{start_gpu_memory} GB of memory reserved.\")"
508
+ ]
509
+ },
510
+ {
511
+ "cell_type": "markdown",
512
+ "id": "OTPkiz3fu0lp",
513
+ "metadata": {},
514
+ "source": [
515
+ "And train!"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "code",
520
+ "execution_count": null,
521
+ "id": "NwI3buPOFMFk",
522
+ "metadata": {},
523
+ "outputs": [],
524
+ "source": [
525
+ "trainer_stats = trainer.train()"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "markdown",
530
+ "id": "ITnLBLcTu2-p",
531
+ "metadata": {},
532
+ "source": [
533
+ "Show memory stats after training"
534
+ ]
535
+ },
536
+ {
537
+ "cell_type": "code",
538
+ "execution_count": null,
539
+ "id": "ftek6m4-GncK",
540
+ "metadata": {},
541
+ "outputs": [],
542
+ "source": [
543
+ "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
544
+ "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
545
+ "used_percentage = round(used_memory / max_memory * 100, 3)\n",
546
+ "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
547
+ "\n",
548
+ "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
549
+ "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
550
+ "print(f\"Peak reserved memory = {used_memory} GB.\")\n",
551
+ "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
552
+ "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
553
+ "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
554
+ ]
555
+ },
556
+ {
557
+ "cell_type": "markdown",
558
+ "id": "O6LAwznKu7mc",
559
+ "metadata": {},
560
+ "source": [
561
+ "Let's save the trained model."
562
+ ]
563
+ },
564
+ {
565
+ "cell_type": "code",
566
+ "execution_count": null,
567
+ "id": "idVgnNS1MWPr",
568
+ "metadata": {},
569
+ "outputs": [],
570
+ "source": [
571
+ "trainer.save_model(output_dir)\n",
572
+ "trainer.push_to_hub()"
573
+ ]
574
+ },
575
+ {
576
+ "cell_type": "markdown",
577
+ "id": "707318cb",
578
+ "metadata": {},
579
+ "source": [
580
+ "## Load the fine-tuned model and run inference using `smolagents`\n",
581
+ "\n",
582
+ "After fine-tuning the model with **GRPO (TRL)** for tool calling, we can test it at inference time using **`smolagents`**, a lightweight library for running multi-step agents.\n",
583
+ "\n",
584
+ "`smolagents` handles the agent loop for us:\n",
585
+ "- Detecting tool calls generated by the model\n",
586
+ "- Executing the corresponding tools (e.g. database queries)\n",
587
+ "- Feeding the results back to the model until a final answer is produced\n",
588
+ "\n",
589
+ "> **Note** \n",
590
+ "> Using an agent framework is optional. The fine-tuned model can also be used directly with `transformers` by manually controlling the inference loop and executing the tools outside the model.\n",
591
+ "> Agent frameworks are especially useful when the number of steps or tool calls is not fixed.\n",
592
+ "\n",
593
+ "We start by installing the required package:\n"
594
+ ]
595
+ },
596
+ {
597
+ "cell_type": "code",
598
+ "execution_count": null,
599
+ "id": "aab7fd5c",
600
+ "metadata": {},
601
+ "outputs": [],
602
+ "source": [
603
+ "!pip install git+https://github.com/huggingface/smolagents.git"
604
+ ]
605
+ },
606
+ {
607
+ "cell_type": "markdown",
608
+ "id": "24453572",
609
+ "metadata": {},
610
+ "source": [
611
+ "We will use the `CodeAgent` class from `smolagents` to instantiate our agent. \n",
612
+ "First, we need to define the tool the agent can use. This is done using the `@tool` decorator.\n",
613
+ "\n",
614
+ "As shown below, the tool definition is **exactly the same** as the one used during GRPO training with TRL. This consistency is important: the model was trained to emit calls following this schema, and at inference time the agent simply executes the corresponding Python function."
615
+ ]
616
+ },
617
+ {
618
+ "cell_type": "code",
619
+ "execution_count": null,
620
+ "id": "adcbbafa",
621
+ "metadata": {},
622
+ "outputs": [],
623
+ "source": [
624
+ "from smolagents import tool\n",
625
+ "\n",
626
+ "@tool\n",
627
+ "def query_biogrid(sql_command: str) -> list[tuple]:\n",
628
+ " \"\"\"\n",
629
+ " Execute a read-only SQL query on the BioGRID database.\n",
630
+ "\n",
631
+ " BioGRID is a curated biological database that compiles protein, genetic,\n",
632
+ " and chemical interactions from multiple organisms.\n",
633
+ "\n",
634
+ " Args:\n",
635
+ " sql_command: A read-only SQL query to execute.\n",
636
+ "\n",
637
+ " Returns:\n",
638
+ " A list of tuples containing the query results.\n",
639
+ " \"\"\"\n",
640
+ " with timeout(5):\n",
641
+ " conn = sqlite3.connect(\n",
642
+ " \"file:biogrid.db?mode=ro\",\n",
643
+ " uri=True,\n",
644
+ " )\n",
645
+ " cursor = conn.cursor()\n",
646
+ " try:\n",
647
+ " cursor.execute(sql_command)\n",
648
+ " results = cursor.fetchall()\n",
649
+ " finally:\n",
650
+ " conn.close()\n",
651
+ "\n",
652
+ " return results"
653
+ ]
654
+ },
655
+ {
656
+ "cell_type": "markdown",
657
+ "id": "59721ad2",
658
+ "metadata": {},
659
+ "source": [
660
+ "Now we can instantiate the agent using our fine-tuned model and the database tool defined above.\n",
661
+ "We wrap the model with `TransformersModel` and pass both the model and the tool when creating the `CodeAgent`."
662
+ ]
663
+ },
664
+ {
665
+ "cell_type": "code",
666
+ "execution_count": null,
667
+ "id": "e9ed8d00",
668
+ "metadata": {},
669
+ "outputs": [],
670
+ "source": [
671
+ "from smolagents import TransformersModel, CodeAgent\n",
672
+ "\n",
673
+ "model = TransformersModel(model_id=\"sergiopaniego/grpo_biogrid_qwen_3g-1.7b\", apply_chat_template_kwargs={\"enable_thinking\": False})\n",
674
+ "\n",
675
+ "# Create an agent with query_biogrid as tool\n",
676
+ "agent = CodeAgent(tools=[query_biogrid], model=model)"
677
+ ]
678
+ },
679
+ {
680
+ "cell_type": "markdown",
681
+ "id": "57ba9462",
682
+ "metadata": {},
683
+ "source": [
684
+ "Finally, we run the agent by passing the full prompt (including the instruction preamble and the question), exactly as it was used during training. This ensures the agent operates under the same context and assumptions learned with GRPO, allowing it to correctly decide when to query the database and how to format the final answer."
685
+ ]
686
+ },
687
+ {
688
+ "cell_type": "code",
689
+ "execution_count": null,
690
+ "id": "23a3cdf4",
691
+ "metadata": {},
692
+ "outputs": [],
693
+ "source": [
694
+ "result = agent.run(train_dataset[0]['prompt'][0]['content'])\n",
695
+ "print(result)"
696
+ ]
697
+ }
698
+ ],
699
+ "metadata": {
700
+ "language_info": {
701
+ "name": "python"
702
+ }
703
+ },
704
+ "nbformat": 4,
705
+ "nbformat_minor": 5
706
+ }
ICL/RL/trl_source/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb ADDED
@@ -0,0 +1,1914 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "lSR2nwdJg962"
7
+ },
8
+ "source": [
9
+ "# Fine-Tune FunctionGemma using Hugging Face TRL and OpenEnv\n",
10
+ "\n",
11
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb)\n",
12
+ "\n",
13
+ "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n",
14
+ "\n",
15
+ "This guide describes the process of fine-tuning [FunctionGemma](https://huggingface.co/google/functiongemma-270m-it) by Google DeepMind in the [BrowserGym](https://meta-pytorch.org/OpenEnv/environments/browsergym/) environment provided by OpenEnv, using Hugging Face TRL. The steps covered include:\n",
16
+ "\n",
17
+ "* What is GRPO and OpenEnv\n",
18
+ "* Setup dependencies for training\n",
19
+ "* Initialize the OpenEnv's BrowserGym environment\n",
20
+ "* Create rollout function with helpers\n",
21
+ "* Define the reward functions\n",
22
+ "* Load the custom dataset\n",
23
+ "* Fine tune using TRL and the GRPOTrainer\n",
24
+ "* Load the fine-tuned model and run inference\n",
25
+ "\n",
26
+ "> Note: The guide is designed to run on Google Colaboratory with access to an NVIDIA A100 GPU (40GB) using FunctionGemma. The workflow can be adapted to other GPU configurations, models, or environments."
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "markdown",
31
+ "metadata": {
32
+ "id": "duXYuR6Cu_na"
33
+ },
34
+ "source": [
35
+ "## What is GRPO and OpenEnv\n",
36
+ "\n",
37
+ "Group Relative Policy Optimization ([GRPO](https://huggingface.co/papers/2402.03300)) is a post-training method widely used for efficiently fine-tuning large language models. GRPO leverages reward functions to guide learning, enabling models to optimize task-specific behaviors without retraining the entire network.\n",
38
+ "\n",
39
+ "[OpenEnv](https://meta-pytorch.org/OpenEnv) provides a standard interface for interacting with agentic execution environments using simple Gymnasium-style APIs, such as `step()`, `reset()`, and `state()`. These APIs facilitate reinforcement learning training loops by allowing models to interact with environments in a structured manner. OpenEnv also offers tools for environment creators to build isolated, secure, and deployable environments that can be shared via common protocols like HTTP or packaged in Docker.\n",
40
+ "\n",
41
+ "The combination of GRPO and OpenEnv enables efficient fine-tuning of models in controlled, interactive tasks while minimizing resource requirements."
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "markdown",
46
+ "metadata": {
47
+ "id": "cpSAQkzKmv50"
48
+ },
49
+ "source": [
50
+ "## Setup dependencies for training\n",
51
+ "\n",
52
+ "Install the required libraries, including Hugging Face TRL for fine-tuning and OpenEnv for reinforcement learning environments."
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "metadata": {
59
+ "id": "c-2drnj5BP56"
60
+ },
61
+ "outputs": [],
62
+ "source": [
63
+ "!pip install -Uq trl[vllm] git+https://huggingface.co/spaces/openenv/browsergym_env liger-kernel trackio"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "markdown",
68
+ "metadata": {
69
+ "id": "Inxeq6ZGpRno"
70
+ },
71
+ "source": [
72
+ "A valid Hugging Face token is required to save the fine-tuned model. In Google Colab, the token can be securely accessed through Colab secrets. Otherwise, it can be provided directly in the login method. Ensure the token has write permissions to allow uploading the model to the Hugging Face Hub during training."
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {
79
+ "id": "C4q5UVu3BP57"
80
+ },
81
+ "outputs": [],
82
+ "source": [
83
+ "from google.colab import userdata\n",
84
+ "from huggingface_hub import login\n",
85
+ "\n",
86
+ "# Login into Hugging Face Hub\n",
87
+ "hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab\n",
88
+ "login(hf_token)"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "markdown",
93
+ "metadata": {
94
+ "id": "O3kr38TGm_hb"
95
+ },
96
+ "source": [
97
+ "## Initialize the OpenEnv's BrowserGym environment\n",
98
+ "\n",
99
+ "External environments can guide the fine-tuning of LLMs for function calling by providing interactive feedback that enhances performance on task-specific behaviors.\n",
100
+ "\n",
101
+ "[BrowserGym](https://meta-pytorch.org/OpenEnv/environments/browsergym/) is a unified framework for web-based agent tasks, offering multiple benchmarks through a Gymnasium-compatible API. It enables training on simple synthetic tasks with [MiniWoB++](https://github.com/Farama-Foundation/miniwob-plusplus) and evaluation on more complex, realistic tasks with [WebArena](https://github.com/web-arena-x/webarena), [VisualWebArena](https://github.com/web-arena-x/visualwebarena), or [WorkArena](https://github.com/ServiceNow/WorkArena). This setup supports iterative training and assessment of web agents without requiring extensive infrastructure.\n",
102
+ "\n",
103
+ "BrowserGym supports both LLM and VLM training by providing visual information, including screenshots and DOM data, which can be utilized depending on the model type. This guide focuses on a simple web-based task called *\"click-test\"*, which is part of the MiniWoB++ benchmark of synthetic web tasks. Environments can be run locally, in Docker containers, or accessed remotely via the Hugging Face Hub. For this example, the remote environment [openenv/browsergym_env](https://huggingface.co/spaces/openenv/browsergym_env) will be used.\n",
104
+ "\n",
105
+ "> Note: Hosted environments on the Hub currently have limited concurrency. For higher reliability or parallel runs, duplicating the Space to your own account is strongly recommended."
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "metadata": {
112
+ "id": "clDs-WQlBP57"
113
+ },
114
+ "outputs": [],
115
+ "source": [
116
+ "from browsergym_env import BrowserGymEnv\n",
117
+ "space_url = \"https://openenv-browsergym-env.hf.space\"\n",
118
+ "\n",
119
+ "client = BrowserGymEnv(base_url=space_url)"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "markdown",
124
+ "metadata": {
125
+ "id": "EqfDavDQnD_5"
126
+ },
127
+ "source": [
128
+ "## Create rollout function with helpers\n",
129
+ "\n",
130
+ "The rollout function defines how the agent interacts with the environment during GRPO training. It generates model outputs, collects feedback in the form of rewards, and returns the information required for optimization.\n",
131
+ "\n",
132
+ "In this setup:\n",
133
+ "- The function is invoked automatically by the GRPOTrainer (introduced later), which orchestrates the training loop and handles policy updates.\n",
134
+ "- It uses the trainer's `generate_rollout_completions()` method for efficient output generation. This leverages vLLM, a high-performance inference engine for large language models, and is integrated within TRL to streamline rollout generation and reward collection during fine-tuning.\n",
135
+ "- Each rollout represents a complete interaction loop, where the model acts, receives feedback from the environment, and updates based on reward signals.\n",
136
+ "\n",
137
+ "Rewards capture various aspects of the agent's performance. Helper functions, such as `rollout_once`, manage individual episodes, keeping the main `rollout_func` clean, modular, and reusable.\n",
138
+ "\n",
139
+ "This modular structure allows GRPO to efficiently sample, evaluate, and refine the model's behavior through reinforcement learning.\n",
140
+ "\n",
141
+ "Before executing rollouts, a `system prompt` is defined to instruct the model on how to interact with the environment. This prompt specifies the available BrowserGym actions (such as `click`, `fill`, `send_keys`, and `scroll`), describes the page structure, and enforces that the model responds with exactly one action per step. It ensures consistent and structured interactions, guiding the model to complete tasks effectively without providing extra explanations or multiple actions."
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "metadata": {
148
+ "id": "ItCXS6H0BP58"
149
+ },
150
+ "outputs": [],
151
+ "source": [
152
+ "# @title System prompt (click to expand)\n",
153
+ "SYSTEM_PROMPT = \"\"\"You control a web browser through BrowserGym actions.\n",
154
+ "You must complete the given web task by interacting with the page.\n",
155
+ "\n",
156
+ "Available actions:\n",
157
+ "- noop() - Do nothing\n",
158
+ "- click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
159
+ "- fill(bid, text) - Fill input field with text\n",
160
+ "- send_keys(text) - Send keyboard input\n",
161
+ "- scroll(direction) - Scroll up/down\n",
162
+ "\n",
163
+ "The page structure shows elements as: [bid] element_type 'element_text'\n",
164
+ "For example: [13] button 'Click Me!' means bid='13'\n",
165
+ "\n",
166
+ "Reply with exactly ONE action on a single line, e.g.:\n",
167
+ "click('13')\n",
168
+ "fill('42', 'hello world')\n",
169
+ "noop()\n",
170
+ "\n",
171
+ "Do not include explanations or multiple actions.\"\"\""
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "markdown",
176
+ "metadata": {
177
+ "id": "Vi1rFey39GUl"
178
+ },
179
+ "source": [
180
+ "The `rollout_func` orchestrates the interaction between the model and the remote BrowserGym environment. For each prompt in the batch, it executes a complete episode using the `rollout_once` function, collecting model outputs and rewards for GRPO optimization.\n",
181
+ "\n",
182
+ "The parameter `max_steps` defines the maximum number of steps the model can take within a single episode. This limits the length of the interaction loop, ensuring that episodes terminate even if the task is not completed, and helps maintain efficient training.\n",
183
+ "\n",
184
+ "During each episode, the function tracks prompt and completion IDs, log probabilities, and both step-wise and final rewards, returning them in a structured format for the trainer to perform policy updates."
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": null,
190
+ "metadata": {
191
+ "id": "CgHd5CFBBP58"
192
+ },
193
+ "outputs": [],
194
+ "source": [
195
+ "from trl import GRPOTrainer\n",
196
+ "\n",
197
+ "max_steps=10\n",
198
+ "\n",
199
+ "def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:\n",
200
+ " episode_prompt_ids: list[list[int]] = []\n",
201
+ " episode_completion_ids: list[list[int]] = []\n",
202
+ " episode_logprobs: list[list[float]] = []\n",
203
+ " completion_rewards: list[float] = []\n",
204
+ "\n",
205
+ " print(f\"\\n[DEBUG] rollout_func called with {len(prompts)} prompts (LLM mode, text-only)\")\n",
206
+ "\n",
207
+ " for i, prompt_text in enumerate(prompts):\n",
208
+ " print(f\"[DEBUG] Processing prompt {i + 1}/{len(prompts)}\")\n",
209
+ " episode = rollout_once(\n",
210
+ " trainer=trainer,\n",
211
+ " env=client,\n",
212
+ " tokenizer=trainer.processing_class,\n",
213
+ " dataset_prompt=prompt_text,\n",
214
+ " max_steps=max_steps,\n",
215
+ " )\n",
216
+ " episode_prompt_ids.append(episode[\"prompt_ids\"])\n",
217
+ " episode_completion_ids.append(episode[\"completion_ids\"])\n",
218
+ " episode_logprobs.append(episode[\"logprobs\"])\n",
219
+ " completion_rewards.append(episode[\"completion_reward\"])\n",
220
+ "\n",
221
+ " return {\n",
222
+ " \"prompt_ids\": episode_prompt_ids,\n",
223
+ " \"completion_ids\": episode_completion_ids,\n",
224
+ " \"logprobs\": episode_logprobs,\n",
225
+ " \"completion_reward\": completion_rewards,\n",
226
+ " }"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "markdown",
231
+ "metadata": {
232
+ "id": "ioUHdIxr9ZQO"
233
+ },
234
+ "source": [
235
+ "### Define `rollout_once`\n",
236
+ "\n",
237
+ "The `rollout_once` function runs one complete interaction loop between the model and the BrowserGym environment using the trainer's generation method. \n",
238
+ "It executes a single episode, from generating an action to receiving feedback and computing rewards.\n",
239
+ "\n",
240
+ "Here's the step-by-step breakdown:\n",
241
+ "\n",
242
+ "1. Environment reset: Start a new BrowserGym session and initialize the observation.\n",
243
+ "2. Prompt construction: Combine the system prompt, environment observation (text-only via the accessibility tree), and any relevant errors or state information to form the model input.\n",
244
+ "3. Generation: Use `trl.experimental.openenv.generate_rollout_completions()` to produce the model's action efficiently with vLLM.\n",
245
+ "4. Action parsing and execution: Interpret the model's output and execute the corresponding BrowserGym action (e.g., `click`, `fill`, `scroll`).\n",
246
+ "5. Reward calculation: Track step-wise rewards provided by the environment and compute completion rewards based on task success or failure.\n",
247
+ "6. Return structured rollout data: Includes prompt/completion IDs, log probabilities, step rewards, and the final reward for the episode.\n",
248
+ "\n",
249
+ "This modular design allows each episode to be processed independently while providing rich feedback for the GRPO training loop, supporting both task completion and intermediate reward shaping."
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": null,
255
+ "metadata": {
256
+ "id": "y8Ml47SYBP58"
257
+ },
258
+ "outputs": [],
259
+ "source": [
260
+ "from trl.experimental.openenv import generate_rollout_completions\n",
261
+ "from browsergym_env import BrowserGymAction\n",
262
+ "from transformers import AutoTokenizer\n",
263
+ "\n",
264
+ "def rollout_once(\n",
265
+ " trainer: GRPOTrainer,\n",
266
+ " env: BrowserGymEnv,\n",
267
+ " tokenizer: AutoTokenizer,\n",
268
+ " dataset_prompt: str,\n",
269
+ " max_steps: int,\n",
270
+ ") -> dict[str, list]:\n",
271
+ " \"\"\"Run one episode and collect training data (text-only, no screenshots).\"\"\"\n",
272
+ " result = env.reset()\n",
273
+ " observation = result.observation\n",
274
+ "\n",
275
+ " prompt_ids: list[int] = []\n",
276
+ " completion_ids: list[int] = []\n",
277
+ " logprobs: list[float] = []\n",
278
+ " step_rewards: list[float] = []\n",
279
+ " completion_rewards: list[float] = []\n",
280
+ "\n",
281
+ " for step_num in range(max_steps):\n",
282
+ " if result.done:\n",
283
+ " break\n",
284
+ "\n",
285
+ " # Create prompt from observation (text-only using accessibility tree)\n",
286
+ " goal = observation.goal or dataset_prompt\n",
287
+ " axtree = observation.axtree_txt or \"\"\n",
288
+ " error = observation.error if observation.last_action_error else \"\"\n",
289
+ "\n",
290
+ " user_prompt = make_user_prompt(goal, step_num, axtree, error)\n",
291
+ " messages = [\n",
292
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
293
+ " {\"role\": \"user\", \"content\": user_prompt},\n",
294
+ " ]\n",
295
+ " prompt_text = tokenizer.apply_chat_template(\n",
296
+ " messages,\n",
297
+ " add_generation_prompt=True,\n",
298
+ " tokenize=False,\n",
299
+ " )\n",
300
+ "\n",
301
+ " # Generate action with vLLM\n",
302
+ " rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]\n",
303
+ " prompt_ids.extend(rollout_outputs[\"prompt_ids\"])\n",
304
+ " completion_ids.extend(rollout_outputs[\"completion_ids\"])\n",
305
+ " logprobs.extend(rollout_outputs[\"logprobs\"])\n",
306
+ "\n",
307
+ " completion_text = rollout_outputs.get(\"text\") or tokenizer.decode(\n",
308
+ " rollout_outputs[\"completion_ids\"], skip_special_tokens=True\n",
309
+ " )\n",
310
+ "\n",
311
+ " # Parse and execute action\n",
312
+ " action_str = parse_action(completion_text)\n",
313
+ "\n",
314
+ " print(f\"Step {step_num + 1}: {action_str}\")\n",
315
+ "\n",
316
+ " # Take action in environment\n",
317
+ " result = env.step(BrowserGymAction(action_str=action_str))\n",
318
+ " observation = result.observation\n",
319
+ "\n",
320
+ " # Track rewards\n",
321
+ " step_reward = float(result.reward or 0.0)\n",
322
+ " step_rewards.append(step_reward)\n",
323
+ "\n",
324
+ " # Reward shaping: success is most important\n",
325
+ " if result.done and step_reward > 0:\n",
326
+ " completion_rewards.append(1.0) # Task completed successfully\n",
327
+ " elif result.done and step_reward == 0:\n",
328
+ " completion_rewards.append(0.0) # Task failed\n",
329
+ " else:\n",
330
+ " completion_rewards.append(step_reward) # Intermediate reward\n",
331
+ "\n",
332
+ " # Final reward is based on task completion\n",
333
+ " final_reward = completion_rewards[-1] if completion_rewards else 0.0\n",
334
+ "\n",
335
+ " return {\n",
336
+ " \"prompt_ids\": prompt_ids,\n",
337
+ " \"completion_ids\": completion_ids,\n",
338
+ " \"logprobs\": logprobs,\n",
339
+ " \"step_rewards\": step_rewards,\n",
340
+ " \"completion_reward\": final_reward,\n",
341
+ " }"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "markdown",
346
+ "metadata": {
347
+ "id": "MDJKMQ__8qzj"
348
+ },
349
+ "source": [
350
+ "### Helper functions\n",
351
+ "\n",
352
+ "Supporting utilities used in `rollout_once`:\n",
353
+ "\n",
354
+ "- `make_user_prompt`: builds the user prompt combining the base text and previous game messages.\n",
355
+ "- `parse_action`: parses BrowserGym action from model response"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": null,
361
+ "metadata": {
362
+ "id": "GG4ba41PBP58"
363
+ },
364
+ "outputs": [],
365
+ "source": [
366
+ "# @title Helpers (click to expand)\n",
367
+ "def make_user_prompt(goal: str, step_num: int, axtree: str, error: str = \"\") -> str:\n",
368
+ " \"\"\"Create user prompt from observation.\"\"\"\n",
369
+ " prompt_parts = [f\"Step {step_num + 1}\"]\n",
370
+ "\n",
371
+ " if goal:\n",
372
+ " prompt_parts.append(f\"Goal: {goal}\")\n",
373
+ "\n",
374
+ " if error:\n",
375
+ " prompt_parts.append(f\"Previous action error: {error}\")\n",
376
+ "\n",
377
+ " # Include accessibility tree (truncated for context)\n",
378
+ " if axtree:\n",
379
+ " max_len = 2000\n",
380
+ " axtree_truncated = axtree[:max_len] + \"...\" if len(axtree) > max_len else axtree\n",
381
+ " prompt_parts.append(f\"Page structure:\\n{axtree_truncated}\")\n",
382
+ "\n",
383
+ " prompt_parts.append(\"What action do you take?\")\n",
384
+ "\n",
385
+ " return \"\\n\\n\".join(prompt_parts)\n",
386
+ "\n",
387
+ "\n",
388
+ "def parse_action(response_text: str) -> str:\n",
389
+ " \"\"\"Parse BrowserGym action from model response.\"\"\"\n",
390
+ " # Extract first line that looks like an action\n",
391
+ " for line in response_text.strip().split(\"\\n\"):\n",
392
+ " line = line.strip()\n",
393
+ " if \"(\" in line and \")\" in line:\n",
394
+ " return line\n",
395
+ "\n",
396
+ " # Fallback to noop if no valid action found\n",
397
+ " return \"noop()\""
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "markdown",
402
+ "metadata": {
403
+ "id": "Oek3JhcWnKhw"
404
+ },
405
+ "source": [
406
+ "## Define the reward functions\n",
407
+ "\n",
408
+ "Reward functions quantify the model's performance in the environment and guide the GRPO optimization process.\n",
409
+ "\n",
410
+ "In this setup, the `reward_completion` function assigns rewards based on task completion. It extracts the final reward for each episode, which indicates whether the agent successfully completed the task. If no reward information is available, it defaults to zero.\n",
411
+ "\n",
412
+ "This modular approach allows additional reward functions to be added easily, enabling more granular feedback such as intermediate progress, efficiency, or correctness of actions, depending on the task requirements."
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": null,
418
+ "metadata": {
419
+ "id": "WxkXaz5aBP59"
420
+ },
421
+ "outputs": [],
422
+ "source": [
423
+ "def reward_completion(completions: list[str], **kwargs) -> list[float]:\n",
424
+ " \"\"\"Reward for task completion.\"\"\"\n",
425
+ " rewards = kwargs.get(\"completion_reward\") if kwargs else None\n",
426
+ " if rewards is None:\n",
427
+ " return [0.0 for _ in completions]\n",
428
+ " return [float(r) for r in rewards]"
429
+ ]
430
+ },
431
+ {
432
+ "cell_type": "markdown",
433
+ "metadata": {
434
+ "id": "66ZsrLplm07U"
435
+ },
436
+ "source": [
437
+ "## Load the custom dataset\n",
438
+ "\n",
439
+ "The dataset is constructed with repeated prompts to control the total number of training episodes.\n",
440
+ "\n",
441
+ "Each entry in the dataset triggers a single rollout episode during training. The `dataset_prompt` provides the initial instruction to the model at the start of each episode, ensuring consistent guidance for task execution."
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "execution_count": null,
447
+ "metadata": {
448
+ "id": "UX6jUjxaBP59"
449
+ },
450
+ "outputs": [],
451
+ "source": [
452
+ "from datasets import Dataset\n",
453
+ "\n",
454
+ "dataset_prompt = \"Complete the web task successfully.\"\n",
455
+ "dataset_size = 1000\n",
456
+ "\n",
457
+ "dataset = Dataset.from_dict({\"prompt\": [dataset_prompt] * dataset_size})"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "markdown",
462
+ "metadata": {
463
+ "id": "-mvka-96m3I7"
464
+ },
465
+ "source": [
466
+ "## Fine-tune using TRL and the GRPOTrainer\n",
467
+ "\n",
468
+ "The next step is to define the GRPOConfig, which sets all key training parameters.\n",
469
+ "\n",
470
+ "This configuration determines how the model interacts with vLLM, handles memory and computation, and records training metrics and logs for monitoring the fine-tuning process."
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "metadata": {
477
+ "id": "TZ34a1h-BP59"
478
+ },
479
+ "outputs": [],
480
+ "source": [
481
+ "from trl import GRPOConfig\n",
482
+ "output_dir = \"browsergym-grpo-functiongemma-270m-it\"\n",
483
+ "\n",
484
+ "grpo_config = GRPOConfig(\n",
485
+ " # num_train_epochs=1, # Number of times to iterate over the full dataset (use for full training runs)\n",
486
+ " max_steps=100, # Number of dataset passes (for shorter runs/testing). For full trainings, use `num_train_epochs` instead\n",
487
+ " learning_rate=5e-6, # Learning rate for the optimizer\n",
488
+ " warmup_steps=10, # Number of steps to linearly increase learning rate at the start of training\n",
489
+ "\n",
490
+ " per_device_train_batch_size=1, # Number of samples per device per step\n",
491
+ " num_generations=4, # Number of completions to generate per prompt\n",
492
+ " generation_batch_size=4, # Batch size used during generation (must be divisible by num_generations)\n",
493
+ " max_completion_length=32, # Maximum length of generated completions\n",
494
+ "\n",
495
+ " use_vllm=True, # Use vLLM engine for fast inference\n",
496
+ " vllm_mode=\"colocate\", # vLLM mode: \"colocate\" runs generation on the same GPU as training\n",
497
+ " vllm_gpu_memory_utilization=0.1, # Fraction of GPU memory allocated to vLLM\n",
498
+ "\n",
499
+ " output_dir=str(output_dir), # Directory where checkpoints, logs, and outputs will be saved\n",
500
+ " logging_steps=1, # Log metrics every N steps\n",
501
+ " report_to=\"trackio\", # Logging/reporting platform (e.g., \"trackio\")\n",
502
+ " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n",
503
+ " push_to_hub=True, # Optionally push trained model to Hugging Face Hub\n",
504
+ "\n",
505
+ " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n",
506
+ ")\n"
507
+ ]
508
+ },
509
+ {
510
+ "cell_type": "markdown",
511
+ "metadata": {
512
+ "id": "a1taGmD--0Y4"
513
+ },
514
+ "source": [
515
+ "The next step is to initialize the GRPOTrainer, which manages the complete reinforcement learning loop.\n",
516
+ "\n",
517
+ "It receives the model name, reward functions, rollout function, and dataset defined earlier. From the model name, the trainer automatically initializes the model and tokenizer. It then coordinates interactions between the model and the environment, applies the defined reward signals, and updates the policy during training.\n",
518
+ "\n",
519
+ "Finally, calling `trainer.train()` starts the fine-tuning process, enabling the model to progressively improve its performance through iterative interaction and reinforcement learning.\n",
520
+ "\n",
521
+ "> Note: The training pipeline uses approximately 10.6 GB of GPU VRAM and can be adapted to different hardware configurations."
522
+ ]
523
+ },
524
+ {
525
+ "cell_type": "code",
526
+ "execution_count": null,
527
+ "metadata": {
528
+ "id": "En43o4NZBP59"
529
+ },
530
+ "outputs": [],
531
+ "source": [
532
+ "model_name = \"google/functiongemma-270m-it\""
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "code",
537
+ "execution_count": null,
538
+ "metadata": {
539
+ "colab": {
540
+ "referenced_widgets": [
541
+ "047d386e54704add95edd4beace781d7"
542
+ ]
543
+ },
544
+ "id": "k8-SvqJcBP59",
545
+ "outputId": "6a4d9276-fc91-4217-d3a2-51a18d222338"
546
+ },
547
+ "outputs": [
548
+ {
549
+ "name": "stderr",
550
+ "output_type": "stream",
551
+ "text": [
552
+ "/tmp/ipython-input-3830121904.py:1: UserWarning: You are importing from 'rollout_func', which is an experimental feature. This API may change or be removed at any time without prior notice. Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.\n",
553
+ " trainer = GRPOTrainer(\n",
554
+ "The model is already on multiple devices. Skipping the move to device specified in `args`.\n",
555
+ "`torch_dtype` is deprecated! Use `dtype` instead!\n"
556
+ ]
557
+ },
558
+ {
559
+ "data": {
560
+ "application/vnd.jupyter.widget-view+json": {
561
+ "model_id": "047d386e54704add95edd4beace781d7",
562
+ "version_major": 2,
563
+ "version_minor": 0
564
+ },
565
+ "text/plain": [
566
+ "Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]\n"
567
+ ]
568
+ },
569
+ "metadata": {},
570
+ "output_type": "display_data"
571
+ },
572
+ {
573
+ "name": "stderr",
574
+ "output_type": "stream",
575
+ "text": [
576
+ "Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 4/4 [00:00<00:00, 19.64it/s]\n"
577
+ ]
578
+ }
579
+ ],
580
+ "source": [
581
+ "trainer = GRPOTrainer(\n",
582
+ " model=model_name,\n",
583
+ " reward_funcs=[reward_completion],\n",
584
+ " train_dataset=dataset,\n",
585
+ " args=grpo_config,\n",
586
+ " rollout_func=rollout_func,\n",
587
+ ")"
588
+ ]
589
+ },
590
+ {
591
+ "cell_type": "code",
592
+ "execution_count": null,
593
+ "metadata": {
594
+ "id": "e1PrBB7gBP59",
595
+ "outputId": "61740a89-228c-4b3c-8e59-b4a3eb972c03"
596
+ },
597
+ "outputs": [
598
+ {
599
+ "name": "stderr",
600
+ "output_type": "stream",
601
+ "text": [
602
+ "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': 2, 'pad_token_id': 0}.\n"
603
+ ]
604
+ },
605
+ {
606
+ "name": "stdout",
607
+ "output_type": "stream",
608
+ "text": [
609
+ "* Trackio project initialized: huggingface\n",
610
+ "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/browsergym-grpo-functiongemma-270m-it-dataset\n",
611
+ "* Creating new space: https://huggingface.co/spaces/sergiopaniego/browsergym-grpo-functiongemma-270m-it\n",
612
+ "* View dashboard by going to: https://sergiopaniego-browsergym-grpo-functiongemma-270m-it.hf.space/\n"
613
+ ]
614
+ },
615
+ {
616
+ "data": {
617
+ "text/html": [
618
+ "<div><iframe src=\"https://sergiopaniego-browsergym-grpo-functiongemma-270m-it.hf.space/\" width=\"100%\" height=\"1000px\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
619
+ ],
620
+ "text/plain": [
621
+ "<IPython.core.display.HTML object>"
622
+ ]
623
+ },
624
+ "metadata": {},
625
+ "output_type": "display_data"
626
+ },
627
+ {
628
+ "name": "stdout",
629
+ "output_type": "stream",
630
+ "text": [
631
+ "* Created new run: sergiopaniego-1765969078\n",
632
+ "\n",
633
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
634
+ "[DEBUG] Processing prompt 1/4\n",
635
+ "Step 1: noop()\n",
636
+ "Step 2: noop()\n",
637
+ "Step 3: noop()\n",
638
+ "Step 4: noop()\n",
639
+ "Step 5: noop()\n",
640
+ "Step 6: noop()\n",
641
+ "Step 7: Click 'click(bid) - Click element with BrowserGym ID (the number in brackets\n",
642
+ "Step 8: I will use the action `click()` to click the button.\n",
643
+ "Step 9: noop()\n",
644
+ "Step 10: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
645
+ "[DEBUG] Processing prompt 2/4\n",
646
+ "Step 1: noop()\n",
647
+ "Step 2: noop()\n",
648
+ "Step 3: Clicks ('13')\n",
649
+ "Step 4: I will click 'Click Me!' using action 'click(bid)' on page 'Click Test Task' using a bid of '13'.\n",
650
+ "Step 5: noop()\n",
651
+ "Step 6: noop()\n",
652
+ "Step 7: noop()\n",
653
+ "Step 8: noop()\n",
654
+ "Step 9: noop()\n",
655
+ "Step 10: noop()\n",
656
+ "[DEBUG] Processing prompt 3/4\n",
657
+ "Step 1: I will use the 'click(bid)' action.\n",
658
+ "Step 2: mouse_click(bid)\n",
659
+ "Step 3: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
660
+ "Step 4: Add action 'click(bid)' to Step 4.\n",
661
+ "Step 5: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
662
+ "Step 6: noop()\n",
663
+ "Step 7: noop()\n",
664
+ "Step 8: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
665
+ "Step 9: noop()\n",
666
+ "Step 10: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
667
+ "[DEBUG] Processing prompt 4/4\n",
668
+ "Step 1: noop()\n",
669
+ "Step 2: noop()\n",
670
+ "Step 3: noop()\n",
671
+ "Step 4: noop()\n",
672
+ "Step 5: Click('13')\n",
673
+ "Step 6: noop()\n",
674
+ "Step 7: noop()\n",
675
+ "Step 8: noop()\n",
676
+ "Step 9: noop()\n",
677
+ "Step 10: noop()\n"
678
+ ]
679
+ },
680
+ {
681
+ "name": "stderr",
682
+ "output_type": "stream",
683
+ "text": [
684
+ "WARNING:liger_kernel.transformers.model.gemma3:It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.\n",
685
+ "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py:282: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n",
686
+ " warnings.warn(\n",
687
+ "/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7095: UserWarning: \n",
688
+ "Online softmax is disabled on the fly since Inductor decides to\n",
689
+ "split the reduction. Cut an issue to PyTorch if this is an\n",
690
+ "important use case and you want to speed it up with online\n",
691
+ "softmax.\n",
692
+ "\n",
693
+ " warnings.warn(\n"
694
+ ]
695
+ },
696
+ {
697
+ "data": {
698
+ "text/html": [
699
+ "\n",
700
+ " <div>\n",
701
+ " \n",
702
+ " <progress value='100' max='100' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
703
+ " [100/100 35:02, Epoch 0/1]\n",
704
+ " </div>\n",
705
+ " <table border=\"1\" class=\"dataframe\">\n",
706
+ " <thead>\n",
707
+ " <tr style=\"text-align: left;\">\n",
708
+ " <th>Step</th>\n",
709
+ " <th>Training Loss</th>\n",
710
+ " </tr>\n",
711
+ " </thead>\n",
712
+ " <tbody>\n",
713
+ " <tr>\n",
714
+ " <td>1</td>\n",
715
+ " <td>0.000000</td>\n",
716
+ " </tr>\n",
717
+ " <tr>\n",
718
+ " <td>2</td>\n",
719
+ " <td>0.000000</td>\n",
720
+ " </tr>\n",
721
+ " <tr>\n",
722
+ " <td>3</td>\n",
723
+ " <td>0.000000</td>\n",
724
+ " </tr>\n",
725
+ " <tr>\n",
726
+ " <td>4</td>\n",
727
+ " <td>0.000000</td>\n",
728
+ " </tr>\n",
729
+ " <tr>\n",
730
+ " <td>5</td>\n",
731
+ " <td>0.000000</td>\n",
732
+ " </tr>\n",
733
+ " <tr>\n",
734
+ " <td>6</td>\n",
735
+ " <td>0.000000</td>\n",
736
+ " </tr>\n",
737
+ " <tr>\n",
738
+ " <td>7</td>\n",
739
+ " <td>0.000000</td>\n",
740
+ " </tr>\n",
741
+ " <tr>\n",
742
+ " <td>8</td>\n",
743
+ " <td>0.000000</td>\n",
744
+ " </tr>\n",
745
+ " <tr>\n",
746
+ " <td>9</td>\n",
747
+ " <td>-0.877900</td>\n",
748
+ " </tr>\n",
749
+ " <tr>\n",
750
+ " <td>10</td>\n",
751
+ " <td>1965.894400</td>\n",
752
+ " </tr>\n",
753
+ " <tr>\n",
754
+ " <td>11</td>\n",
755
+ " <td>-0.830900</td>\n",
756
+ " </tr>\n",
757
+ " <tr>\n",
758
+ " <td>12</td>\n",
759
+ " <td>10.616100</td>\n",
760
+ " </tr>\n",
761
+ " <tr>\n",
762
+ " <td>13</td>\n",
763
+ " <td>0.000000</td>\n",
764
+ " </tr>\n",
765
+ " <tr>\n",
766
+ " <td>14</td>\n",
767
+ " <td>0.000000</td>\n",
768
+ " </tr>\n",
769
+ " <tr>\n",
770
+ " <td>15</td>\n",
771
+ " <td>0.000000</td>\n",
772
+ " </tr>\n",
773
+ " <tr>\n",
774
+ " <td>16</td>\n",
775
+ " <td>0.000000</td>\n",
776
+ " </tr>\n",
777
+ " <tr>\n",
778
+ " <td>17</td>\n",
779
+ " <td>2.320100</td>\n",
780
+ " </tr>\n",
781
+ " <tr>\n",
782
+ " <td>18</td>\n",
783
+ " <td>1.887500</td>\n",
784
+ " </tr>\n",
785
+ " <tr>\n",
786
+ " <td>19</td>\n",
787
+ " <td>-0.691600</td>\n",
788
+ " </tr>\n",
789
+ " <tr>\n",
790
+ " <td>20</td>\n",
791
+ " <td>-0.764400</td>\n",
792
+ " </tr>\n",
793
+ " <tr>\n",
794
+ " <td>21</td>\n",
795
+ " <td>0.000000</td>\n",
796
+ " </tr>\n",
797
+ " <tr>\n",
798
+ " <td>22</td>\n",
799
+ " <td>0.000000</td>\n",
800
+ " </tr>\n",
801
+ " <tr>\n",
802
+ " <td>23</td>\n",
803
+ " <td>0.000000</td>\n",
804
+ " </tr>\n",
805
+ " <tr>\n",
806
+ " <td>24</td>\n",
807
+ " <td>0.000000</td>\n",
808
+ " </tr>\n",
809
+ " <tr>\n",
810
+ " <td>25</td>\n",
811
+ " <td>0.000000</td>\n",
812
+ " </tr>\n",
813
+ " <tr>\n",
814
+ " <td>26</td>\n",
815
+ " <td>0.000000</td>\n",
816
+ " </tr>\n",
817
+ " <tr>\n",
818
+ " <td>27</td>\n",
819
+ " <td>0.000000</td>\n",
820
+ " </tr>\n",
821
+ " <tr>\n",
822
+ " <td>28</td>\n",
823
+ " <td>0.000000</td>\n",
824
+ " </tr>\n",
825
+ " <tr>\n",
826
+ " <td>29</td>\n",
827
+ " <td>0.000000</td>\n",
828
+ " </tr>\n",
829
+ " <tr>\n",
830
+ " <td>30</td>\n",
831
+ " <td>0.000000</td>\n",
832
+ " </tr>\n",
833
+ " <tr>\n",
834
+ " <td>31</td>\n",
835
+ " <td>0.000000</td>\n",
836
+ " </tr>\n",
837
+ " <tr>\n",
838
+ " <td>32</td>\n",
839
+ " <td>0.000000</td>\n",
840
+ " </tr>\n",
841
+ " <tr>\n",
842
+ " <td>33</td>\n",
843
+ " <td>0.000000</td>\n",
844
+ " </tr>\n",
845
+ " <tr>\n",
846
+ " <td>34</td>\n",
847
+ " <td>0.000000</td>\n",
848
+ " </tr>\n",
849
+ " <tr>\n",
850
+ " <td>35</td>\n",
851
+ " <td>0.000000</td>\n",
852
+ " </tr>\n",
853
+ " <tr>\n",
854
+ " <td>36</td>\n",
855
+ " <td>0.000000</td>\n",
856
+ " </tr>\n",
857
+ " <tr>\n",
858
+ " <td>37</td>\n",
859
+ " <td>0.000000</td>\n",
860
+ " </tr>\n",
861
+ " <tr>\n",
862
+ " <td>38</td>\n",
863
+ " <td>0.000000</td>\n",
864
+ " </tr>\n",
865
+ " <tr>\n",
866
+ " <td>39</td>\n",
867
+ " <td>0.000000</td>\n",
868
+ " </tr>\n",
869
+ " <tr>\n",
870
+ " <td>40</td>\n",
871
+ " <td>0.000000</td>\n",
872
+ " </tr>\n",
873
+ " <tr>\n",
874
+ " <td>41</td>\n",
875
+ " <td>0.000000</td>\n",
876
+ " </tr>\n",
877
+ " <tr>\n",
878
+ " <td>42</td>\n",
879
+ " <td>0.000000</td>\n",
880
+ " </tr>\n",
881
+ " <tr>\n",
882
+ " <td>43</td>\n",
883
+ " <td>0.000000</td>\n",
884
+ " </tr>\n",
885
+ " <tr>\n",
886
+ " <td>44</td>\n",
887
+ " <td>0.000000</td>\n",
888
+ " </tr>\n",
889
+ " <tr>\n",
890
+ " <td>45</td>\n",
891
+ " <td>0.000000</td>\n",
892
+ " </tr>\n",
893
+ " <tr>\n",
894
+ " <td>46</td>\n",
895
+ " <td>0.000000</td>\n",
896
+ " </tr>\n",
897
+ " <tr>\n",
898
+ " <td>47</td>\n",
899
+ " <td>0.000000</td>\n",
900
+ " </tr>\n",
901
+ " <tr>\n",
902
+ " <td>48</td>\n",
903
+ " <td>0.000000</td>\n",
904
+ " </tr>\n",
905
+ " <tr>\n",
906
+ " <td>49</td>\n",
907
+ " <td>0.000000</td>\n",
908
+ " </tr>\n",
909
+ " <tr>\n",
910
+ " <td>50</td>\n",
911
+ " <td>0.000000</td>\n",
912
+ " </tr>\n",
913
+ " <tr>\n",
914
+ " <td>51</td>\n",
915
+ " <td>0.000000</td>\n",
916
+ " </tr>\n",
917
+ " <tr>\n",
918
+ " <td>52</td>\n",
919
+ " <td>0.000000</td>\n",
920
+ " </tr>\n",
921
+ " <tr>\n",
922
+ " <td>53</td>\n",
923
+ " <td>0.000000</td>\n",
924
+ " </tr>\n",
925
+ " <tr>\n",
926
+ " <td>54</td>\n",
927
+ " <td>0.000000</td>\n",
928
+ " </tr>\n",
929
+ " <tr>\n",
930
+ " <td>55</td>\n",
931
+ " <td>0.000000</td>\n",
932
+ " </tr>\n",
933
+ " <tr>\n",
934
+ " <td>56</td>\n",
935
+ " <td>0.000000</td>\n",
936
+ " </tr>\n",
937
+ " <tr>\n",
938
+ " <td>57</td>\n",
939
+ " <td>0.000000</td>\n",
940
+ " </tr>\n",
941
+ " <tr>\n",
942
+ " <td>58</td>\n",
943
+ " <td>0.000000</td>\n",
944
+ " </tr>\n",
945
+ " <tr>\n",
946
+ " <td>59</td>\n",
947
+ " <td>0.000000</td>\n",
948
+ " </tr>\n",
949
+ " <tr>\n",
950
+ " <td>60</td>\n",
951
+ " <td>0.000000</td>\n",
952
+ " </tr>\n",
953
+ " <tr>\n",
954
+ " <td>61</td>\n",
955
+ " <td>0.000000</td>\n",
956
+ " </tr>\n",
957
+ " <tr>\n",
958
+ " <td>62</td>\n",
959
+ " <td>0.000000</td>\n",
960
+ " </tr>\n",
961
+ " <tr>\n",
962
+ " <td>63</td>\n",
963
+ " <td>0.000000</td>\n",
964
+ " </tr>\n",
965
+ " <tr>\n",
966
+ " <td>64</td>\n",
967
+ " <td>0.000000</td>\n",
968
+ " </tr>\n",
969
+ " <tr>\n",
970
+ " <td>65</td>\n",
971
+ " <td>0.000000</td>\n",
972
+ " </tr>\n",
973
+ " <tr>\n",
974
+ " <td>66</td>\n",
975
+ " <td>0.000000</td>\n",
976
+ " </tr>\n",
977
+ " <tr>\n",
978
+ " <td>67</td>\n",
979
+ " <td>0.000000</td>\n",
980
+ " </tr>\n",
981
+ " <tr>\n",
982
+ " <td>68</td>\n",
983
+ " <td>0.000000</td>\n",
984
+ " </tr>\n",
985
+ " <tr>\n",
986
+ " <td>69</td>\n",
987
+ " <td>0.000000</td>\n",
988
+ " </tr>\n",
989
+ " <tr>\n",
990
+ " <td>70</td>\n",
991
+ " <td>0.000000</td>\n",
992
+ " </tr>\n",
993
+ " <tr>\n",
994
+ " <td>71</td>\n",
995
+ " <td>0.000000</td>\n",
996
+ " </tr>\n",
997
+ " <tr>\n",
998
+ " <td>72</td>\n",
999
+ " <td>0.000000</td>\n",
1000
+ " </tr>\n",
1001
+ " <tr>\n",
1002
+ " <td>73</td>\n",
1003
+ " <td>0.000000</td>\n",
1004
+ " </tr>\n",
1005
+ " <tr>\n",
1006
+ " <td>74</td>\n",
1007
+ " <td>0.000000</td>\n",
1008
+ " </tr>\n",
1009
+ " <tr>\n",
1010
+ " <td>75</td>\n",
1011
+ " <td>0.000000</td>\n",
1012
+ " </tr>\n",
1013
+ " <tr>\n",
1014
+ " <td>76</td>\n",
1015
+ " <td>0.000000</td>\n",
1016
+ " </tr>\n",
1017
+ " <tr>\n",
1018
+ " <td>77</td>\n",
1019
+ " <td>0.000000</td>\n",
1020
+ " </tr>\n",
1021
+ " <tr>\n",
1022
+ " <td>78</td>\n",
1023
+ " <td>0.000000</td>\n",
1024
+ " </tr>\n",
1025
+ " <tr>\n",
1026
+ " <td>79</td>\n",
1027
+ " <td>0.000000</td>\n",
1028
+ " </tr>\n",
1029
+ " <tr>\n",
1030
+ " <td>80</td>\n",
1031
+ " <td>0.000000</td>\n",
1032
+ " </tr>\n",
1033
+ " <tr>\n",
1034
+ " <td>81</td>\n",
1035
+ " <td>0.000000</td>\n",
1036
+ " </tr>\n",
1037
+ " <tr>\n",
1038
+ " <td>82</td>\n",
1039
+ " <td>0.000000</td>\n",
1040
+ " </tr>\n",
1041
+ " <tr>\n",
1042
+ " <td>83</td>\n",
1043
+ " <td>0.000000</td>\n",
1044
+ " </tr>\n",
1045
+ " <tr>\n",
1046
+ " <td>84</td>\n",
1047
+ " <td>0.000000</td>\n",
1048
+ " </tr>\n",
1049
+ " <tr>\n",
1050
+ " <td>85</td>\n",
1051
+ " <td>0.000000</td>\n",
1052
+ " </tr>\n",
1053
+ " <tr>\n",
1054
+ " <td>86</td>\n",
1055
+ " <td>0.000000</td>\n",
1056
+ " </tr>\n",
1057
+ " <tr>\n",
1058
+ " <td>87</td>\n",
1059
+ " <td>0.000000</td>\n",
1060
+ " </tr>\n",
1061
+ " <tr>\n",
1062
+ " <td>88</td>\n",
1063
+ " <td>0.000000</td>\n",
1064
+ " </tr>\n",
1065
+ " <tr>\n",
1066
+ " <td>89</td>\n",
1067
+ " <td>0.000000</td>\n",
1068
+ " </tr>\n",
1069
+ " <tr>\n",
1070
+ " <td>90</td>\n",
1071
+ " <td>0.000000</td>\n",
1072
+ " </tr>\n",
1073
+ " <tr>\n",
1074
+ " <td>91</td>\n",
1075
+ " <td>0.000000</td>\n",
1076
+ " </tr>\n",
1077
+ " <tr>\n",
1078
+ " <td>92</td>\n",
1079
+ " <td>0.000000</td>\n",
1080
+ " </tr>\n",
1081
+ " <tr>\n",
1082
+ " <td>93</td>\n",
1083
+ " <td>0.000000</td>\n",
1084
+ " </tr>\n",
1085
+ " <tr>\n",
1086
+ " <td>94</td>\n",
1087
+ " <td>0.000000</td>\n",
1088
+ " </tr>\n",
1089
+ " <tr>\n",
1090
+ " <td>95</td>\n",
1091
+ " <td>0.000000</td>\n",
1092
+ " </tr>\n",
1093
+ " <tr>\n",
1094
+ " <td>96</td>\n",
1095
+ " <td>0.000000</td>\n",
1096
+ " </tr>\n",
1097
+ " <tr>\n",
1098
+ " <td>97</td>\n",
1099
+ " <td>0.000000</td>\n",
1100
+ " </tr>\n",
1101
+ " <tr>\n",
1102
+ " <td>98</td>\n",
1103
+ " <td>0.000000</td>\n",
1104
+ " </tr>\n",
1105
+ " <tr>\n",
1106
+ " <td>99</td>\n",
1107
+ " <td>0.000000</td>\n",
1108
+ " </tr>\n",
1109
+ " <tr>\n",
1110
+ " <td>100</td>\n",
1111
+ " <td>0.000000</td>\n",
1112
+ " </tr>\n",
1113
+ " </tbody>\n",
1114
+ "</table><p>"
1115
+ ],
1116
+ "text/plain": [
1117
+ "<IPython.core.display.HTML object>"
1118
+ ]
1119
+ },
1120
+ "metadata": {},
1121
+ "output_type": "display_data"
1122
+ },
1123
+ {
1124
+ "name": "stdout",
1125
+ "output_type": "stream",
1126
+ "text": [
1127
+ "\n",
1128
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1129
+ "[DEBUG] Processing prompt 1/4\n",
1130
+ "Step 1: Clicks ('13')\n",
1131
+ "Step 2: noop()\n",
1132
+ "Step 3: noop()\n",
1133
+ "Step 4: noop()\n",
1134
+ "Step 5: noop()\n",
1135
+ "Step 6: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1136
+ "Step 7: noop()\n",
1137
+ "Step 8: noop()\n",
1138
+ "Step 9: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1139
+ "Step 10: noop()\n",
1140
+ "[DEBUG] Processing prompt 2/4\n",
1141
+ "Step 1: noop()\n",
1142
+ "Step 2: I will use action: click(bid) to click the button.\n",
1143
+ "Step 3: Yes, I can handle this. I will use the `click()` action to click the button.\n",
1144
+ "Step 4: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1145
+ "Step 5: noop()\n",
1146
+ "Step 6: noop()\n",
1147
+ "Step 7: noop()\n",
1148
+ "Step 8: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1149
+ "Step 9: noop()\n",
1150
+ "Step 10: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1151
+ "[DEBUG] Processing prompt 3/4\n",
1152
+ "Step 1: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1153
+ "Step 2: noop()\n",
1154
+ "Step 3: noop()\n",
1155
+ "Step 4: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1156
+ "Step 5: noop()\n",
1157
+ "Step 6: noop()\n",
1158
+ "Step 7: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1159
+ "Step 8: noop()\n",
1160
+ "Step 9: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1161
+ "Step 10: Pass the button ID ('Click Me!') to the action \"click('bid')\".\n",
1162
+ "[DEBUG] Processing prompt 4/4\n",
1163
+ "Step 1: noop()\n",
1164
+ "Step 2: noop()\n",
1165
+ "Step 3: noop()\n",
1166
+ "Step 4: noop()\n",
1167
+ "Step 5: I will click the button by emitting `click(bid)` and `fill(bid, text)` simultaneously.\n",
1168
+ "Step 6: noop()\n",
1169
+ "Step 7: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1170
+ "Step 8: noop()\n",
1171
+ "Step 9: noop()\n",
1172
+ "Step 10: noop()\n",
1173
+ "\n",
1174
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1175
+ "[DEBUG] Processing prompt 1/4\n",
1176
+ "Step 1: - Noop()\n",
1177
+ "Step 2: noop()\n",
1178
+ "Step 3: -noop()\n",
1179
+ "Step 4: noop()\n",
1180
+ "Step 5: Click('13')\n",
1181
+ "Step 6: noop()\n",
1182
+ "Step 7: noop()\n",
1183
+ "Step 8: noop()\n",
1184
+ "Step 9: noop()\n",
1185
+ "Step 10: noop()\n",
1186
+ "[DEBUG] Processing prompt 2/4\n",
1187
+ "Step 1: noop()\n",
1188
+ "Step 2: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1189
+ "Step 3: noop()\n",
1190
+ "Step 4: noop()\n",
1191
+ "Step 5: noop()\n",
1192
+ "Step 6: Complete action: click('13')\n",
1193
+ "[DEBUG] Processing prompt 3/4\n",
1194
+ "Step 1: I will use the action 'click('bid') to click the button.\n",
1195
+ "Step 2: noop()\n",
1196
+ "Step 3: noop()\n",
1197
+ "Step 4: noop()\n",
1198
+ "Step 5: noop()\n",
1199
+ "Step 6: I call action Click (bid) on the page.\n",
1200
+ "Step 7: noop()\n",
1201
+ "Step 8: noop()\n",
1202
+ "Step 9: noop()\n",
1203
+ "Step 10: noop()\n",
1204
+ "[DEBUG] Processing prompt 4/4\n",
1205
+ "Step 1: Oops()\n",
1206
+ "Step 2: noop()\n",
1207
+ "Step 3: fill(bid, text)\n",
1208
+ "Step 4: noop()\n",
1209
+ "Step 5: click('13')\n",
1210
+ "\n",
1211
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1212
+ "[DEBUG] Processing prompt 1/4\n",
1213
+ "Step 1: def click_button_on_page():\n",
1214
+ "Step 2: noop()\n",
1215
+ "Step 3: click(bid)\n",
1216
+ "Step 4: Click('13')\n",
1217
+ "Step 5: noop()\n",
1218
+ "Step 6: noop()\n",
1219
+ "Step 7: noop()\n",
1220
+ "Step 8: noop()\n",
1221
+ "Step 9: noop()\n",
1222
+ "Step 10: noop()\n",
1223
+ "[DEBUG] Processing prompt 2/4\n",
1224
+ "Step 1: noop()\n",
1225
+ "Step 2: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1226
+ "Step 3: noop()\n",
1227
+ "Step 4: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1228
+ "Step 5: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1229
+ "Step 6: I will click the button 'Click Me!' by using the action `click(bid)` and emitting a bid of 13.\n",
1230
+ "Step 7: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1231
+ "Step 8: noop()\n",
1232
+ "Step 9: noop()\n",
1233
+ "Step 10: noop()\n",
1234
+ "[DEBUG] Processing prompt 3/4\n",
1235
+ "Step 1: `click(bid)` - No action\n",
1236
+ "Step 2: - Noop()\n",
1237
+ "Step 3: noop()\n",
1238
+ "Step 4: noop()\n",
1239
+ "Step 5: noop()\n",
1240
+ "Step 6: noop()\n",
1241
+ "Step 7: noop()\n",
1242
+ "Step 8: noop()\n",
1243
+ "Step 9: noop()\n",
1244
+ "Step 10: I will click the button 'Click Me!' using the action 'click(bid)'.\n",
1245
+ "[DEBUG] Processing prompt 4/4\n",
1246
+ "Step 1: noop()\n",
1247
+ "Step 2: noop()\n",
1248
+ "Step 3: noop()\n",
1249
+ "Step 4: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1250
+ "Step 5: noop()\n",
1251
+ "Step 6: noop()\n",
1252
+ "Step 7: noop()\n",
1253
+ "Step 8: noop()\n",
1254
+ "Step 9: Complete action: click(bid)\n",
1255
+ "Step 10: noop()\n",
1256
+ "\n",
1257
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1258
+ "[DEBUG] Processing prompt 1/4\n",
1259
+ "Step 1: click('13')\n",
1260
+ "[DEBUG] Processing prompt 2/4\n",
1261
+ "Step 1: noop()\n",
1262
+ "Step 2: I will perform action 1: click('13') to complete the action.\n",
1263
+ "[DEBUG] Processing prompt 3/4\n",
1264
+ "Step 1: noop()\n",
1265
+ "Step 2: noop()\n",
1266
+ "Step 3: noop()\n",
1267
+ "Step 4: noop()\n",
1268
+ "Step 5: noop()\n",
1269
+ "Step 6: noop()\n",
1270
+ "Step 7: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1271
+ "Step 8: noop()\n",
1272
+ "Step 9: Click ('13')\n",
1273
+ "Step 10: Add action 'fill(bid, text) - Send keyboard input' to perform the click.\n",
1274
+ "[DEBUG] Processing prompt 4/4\n",
1275
+ "Step 1: noop()\n",
1276
+ "Step 2: Click('click(bid) - Bid')\n",
1277
+ "Step 3: noop()\n",
1278
+ "Step 4: noop()\n",
1279
+ "Step 5: noop()\n",
1280
+ "Step 6: noop()\n",
1281
+ "Step 7: noop()\n",
1282
+ "Step 8: noop()\n",
1283
+ "Step 9: click(bid) - Click element with BrowserGym ID (the number in brackets)\n",
1284
+ "Step 10: noop()\n",
1285
+ "\n",
1286
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1287
+ "[DEBUG] Processing prompt 1/4\n",
1288
+ "Step 1: click('13')\n",
1289
+ "[DEBUG] Processing prompt 2/4\n",
1290
+ "Step 1: click('13')\n",
1291
+ "[DEBUG] Processing prompt 3/4\n",
1292
+ "Step 1: click('13')\n",
1293
+ "[DEBUG] Processing prompt 4/4\n",
1294
+ "Step 1: click('13')\n",
1295
+ "\n",
1296
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1297
+ "[DEBUG] Processing prompt 1/4\n",
1298
+ "Step 1: click('13')\n",
1299
+ "[DEBUG] Processing prompt 2/4\n",
1300
+ "Step 1: click('13')\n",
1301
+ "[DEBUG] Processing prompt 3/4\n",
1302
+ "Step 1: click('13')\n",
1303
+ "[DEBUG] Processing prompt 4/4\n",
1304
+ "Step 1: click('13')\n",
1305
+ "\n",
1306
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1307
+ "[DEBUG] Processing prompt 1/4\n",
1308
+ "Step 1: click('13')\n",
1309
+ "[DEBUG] Processing prompt 2/4\n",
1310
+ "Step 1: click('13')\n",
1311
+ "[DEBUG] Processing prompt 3/4\n",
1312
+ "Step 1: click('13')\n",
1313
+ "[DEBUG] Processing prompt 4/4\n",
1314
+ "Step 1: click('13')\n",
1315
+ "\n",
1316
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1317
+ "[DEBUG] Processing prompt 1/4\n",
1318
+ "Step 1: click('13')\n",
1319
+ "[DEBUG] Processing prompt 2/4\n",
1320
+ "Step 1: click('13')\n",
1321
+ "[DEBUG] Processing prompt 3/4\n",
1322
+ "Step 1: click('13')\n",
1323
+ "[DEBUG] Processing prompt 4/4\n",
1324
+ "Step 1: click('13')\n",
1325
+ "\n",
1326
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1327
+ "[DEBUG] Processing prompt 1/4\n",
1328
+ "Step 1: click('13')\n",
1329
+ "[DEBUG] Processing prompt 2/4\n",
1330
+ "Step 1: click('13')\n",
1331
+ "[DEBUG] Processing prompt 3/4\n",
1332
+ "Step 1: click('13')\n",
1333
+ "[DEBUG] Processing prompt 4/4\n",
1334
+ "Step 1: click('13')\n",
1335
+ "\n",
1336
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1337
+ "[DEBUG] Processing prompt 1/4\n",
1338
+ "Step 1: click('13')\n",
1339
+ "[DEBUG] Processing prompt 2/4\n",
1340
+ "Step 1: click('13')\n",
1341
+ "[DEBUG] Processing prompt 3/4\n",
1342
+ "Step 1: click('13')\n",
1343
+ "[DEBUG] Processing prompt 4/4\n",
1344
+ "Step 1: click('13')\n",
1345
+ "\n",
1346
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1347
+ "[DEBUG] Processing prompt 1/4\n",
1348
+ "Step 1: click('13')\n",
1349
+ "[DEBUG] Processing prompt 2/4\n",
1350
+ "Step 1: click('13')\n",
1351
+ "[DEBUG] Processing prompt 3/4\n",
1352
+ "Step 1: click('13')\n",
1353
+ "[DEBUG] Processing prompt 4/4\n",
1354
+ "Step 1: click('13')\n",
1355
+ "\n",
1356
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1357
+ "[DEBUG] Processing prompt 1/4\n",
1358
+ "Step 1: click('13')\n",
1359
+ "[DEBUG] Processing prompt 2/4\n",
1360
+ "Step 1: click('13')\n",
1361
+ "[DEBUG] Processing prompt 3/4\n",
1362
+ "Step 1: click('13')\n",
1363
+ "[DEBUG] Processing prompt 4/4\n",
1364
+ "Step 1: click('13')\n",
1365
+ "\n",
1366
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1367
+ "[DEBUG] Processing prompt 1/4\n",
1368
+ "Step 1: click('13')\n",
1369
+ "[DEBUG] Processing prompt 2/4\n",
1370
+ "Step 1: click('13')\n",
1371
+ "[DEBUG] Processing prompt 3/4\n",
1372
+ "Step 1: click('13')\n",
1373
+ "[DEBUG] Processing prompt 4/4\n",
1374
+ "Step 1: click('13')\n",
1375
+ "\n",
1376
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1377
+ "[DEBUG] Processing prompt 1/4\n",
1378
+ "Step 1: click('13')\n",
1379
+ "[DEBUG] Processing prompt 2/4\n",
1380
+ "Step 1: click('13')\n",
1381
+ "[DEBUG] Processing prompt 3/4\n",
1382
+ "Step 1: click('13')\n",
1383
+ "[DEBUG] Processing prompt 4/4\n",
1384
+ "Step 1: click('13')\n",
1385
+ "\n",
1386
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1387
+ "[DEBUG] Processing prompt 1/4\n",
1388
+ "Step 1: click('13')\n",
1389
+ "[DEBUG] Processing prompt 2/4\n",
1390
+ "Step 1: click('13')\n",
1391
+ "[DEBUG] Processing prompt 3/4\n",
1392
+ "Step 1: click('13')\n",
1393
+ "[DEBUG] Processing prompt 4/4\n",
1394
+ "Step 1: click('13')\n",
1395
+ "\n",
1396
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1397
+ "[DEBUG] Processing prompt 1/4\n",
1398
+ "Step 1: click('13')\n",
1399
+ "[DEBUG] Processing prompt 2/4\n",
1400
+ "Step 1: click('13')\n",
1401
+ "[DEBUG] Processing prompt 3/4\n",
1402
+ "Step 1: click('13')\n",
1403
+ "[DEBUG] Processing prompt 4/4\n",
1404
+ "Step 1: click('13')\n",
1405
+ "\n",
1406
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1407
+ "[DEBUG] Processing prompt 1/4\n",
1408
+ "Step 1: click('13')\n",
1409
+ "[DEBUG] Processing prompt 2/4\n",
1410
+ "Step 1: click('13')\n",
1411
+ "[DEBUG] Processing prompt 3/4\n",
1412
+ "Step 1: click('13')\n",
1413
+ "[DEBUG] Processing prompt 4/4\n",
1414
+ "Step 1: click('13')\n",
1415
+ "\n",
1416
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1417
+ "[DEBUG] Processing prompt 1/4\n",
1418
+ "Step 1: click('13')\n",
1419
+ "[DEBUG] Processing prompt 2/4\n",
1420
+ "Step 1: click('13')\n",
1421
+ "[DEBUG] Processing prompt 3/4\n",
1422
+ "Step 1: click('13')\n",
1423
+ "[DEBUG] Processing prompt 4/4\n",
1424
+ "Step 1: click('13')\n",
1425
+ "\n",
1426
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1427
+ "[DEBUG] Processing prompt 1/4\n",
1428
+ "Step 1: click('13')\n",
1429
+ "[DEBUG] Processing prompt 2/4\n",
1430
+ "Step 1: click('13')\n",
1431
+ "[DEBUG] Processing prompt 3/4\n",
1432
+ "Step 1: click('13')\n",
1433
+ "[DEBUG] Processing prompt 4/4\n",
1434
+ "Step 1: click('13')\n",
1435
+ "\n",
1436
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1437
+ "[DEBUG] Processing prompt 1/4\n",
1438
+ "Step 1: click('13')\n",
1439
+ "[DEBUG] Processing prompt 2/4\n",
1440
+ "Step 1: click('13')\n",
1441
+ "[DEBUG] Processing prompt 3/4\n",
1442
+ "Step 1: click('13')\n",
1443
+ "[DEBUG] Processing prompt 4/4\n",
1444
+ "Step 1: click('13')\n",
1445
+ "\n",
1446
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1447
+ "[DEBUG] Processing prompt 1/4\n",
1448
+ "Step 1: click('13')\n",
1449
+ "[DEBUG] Processing prompt 2/4\n",
1450
+ "Step 1: click('13')\n",
1451
+ "[DEBUG] Processing prompt 3/4\n",
1452
+ "Step 1: click('13')\n",
1453
+ "[DEBUG] Processing prompt 4/4\n",
1454
+ "Step 1: click('13')\n",
1455
+ "\n",
1456
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1457
+ "[DEBUG] Processing prompt 1/4\n",
1458
+ "Step 1: click('13')\n",
1459
+ "[DEBUG] Processing prompt 2/4\n",
1460
+ "Step 1: click('13')\n",
1461
+ "[DEBUG] Processing prompt 3/4\n",
1462
+ "Step 1: click('13')\n",
1463
+ "[DEBUG] Processing prompt 4/4\n",
1464
+ "Step 1: click('13')\n",
1465
+ "\n",
1466
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1467
+ "[DEBUG] Processing prompt 1/4\n",
1468
+ "Step 1: click('13')\n",
1469
+ "[DEBUG] Processing prompt 2/4\n",
1470
+ "Step 1: click('13')\n",
1471
+ "[DEBUG] Processing prompt 3/4\n",
1472
+ "Step 1: click('13')\n",
1473
+ "[DEBUG] Processing prompt 4/4\n",
1474
+ "Step 1: click('13')\n",
1475
+ "\n",
1476
+ "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n",
1477
+ "[DEBUG] Processing prompt 1/4\n",
1478
+ "Step 1: click('13')\n",
1479
+ "[DEBUG] Processing prompt 2/4\n",
1480
+ "Step 1: click('13')\n",
1481
+ "[DEBUG] Processing prompt 3/4\n",
1482
+ "Step 1: click('13')\n",
1483
+ "[DEBUG] Processing prompt 4/4\n",
1484
+ "Step 1: click('13')\n",
1485
+ "* Run finished. Uploading logs to Trackio (please wait...)\n"
1486
+ ]
1487
+ }
1488
+ ],
1489
+ "source": [
1490
+ "trainer_stats = trainer.train()"
1491
+ ]
1492
+ },
1493
+ {
1494
+ "cell_type": "markdown",
1495
+ "metadata": {
1496
+ "id": "BZj4IG9ZBAix"
1497
+ },
1498
+ "source": [
1499
+ "In this step, the fine-tuned model is saved locally and uploaded to the Hugging Face Hub using the configured account credentials."
1500
+ ]
1501
+ },
1502
+ {
1503
+ "cell_type": "code",
1504
+ "execution_count": null,
1505
+ "metadata": {
1506
+ "colab": {
1507
+ "referenced_widgets": [
1508
+ "244ced1920694dbaae9bf98065b4f01d",
1509
+ "e3769ae107554c9ba38c1e491b15bf4e",
1510
+ "6d5b8bff73474faeb1d1b438fb4e8cec",
1511
+ "9f952f8eb63b42e4b38711737da5461e",
1512
+ "bd12780895064467b5be14e2ec3df114",
1513
+ "d1261c1083a74dca877e6eece6395d73",
1514
+ "999744cacd6a4fb08a1d4977ce2f06fd",
1515
+ "faa5e0fb4ee244689c0f9eef9902acf7",
1516
+ "6403bed2cd984ba18f74f416748c64e4",
1517
+ "38be017369524e2eb22050e7a0a18ec5",
1518
+ "b0720a4a2df948308011d4d87a288426",
1519
+ "889ca2520f4d446daf2e6ed16ce11d2e"
1520
+ ]
1521
+ },
1522
+ "id": "9oOBgEWeBP59",
1523
+ "outputId": "76bef375-fc6b-4fdd-a296-549a9b109b11"
1524
+ },
1525
+ "outputs": [
1526
+ {
1527
+ "data": {
1528
+ "application/vnd.jupyter.widget-view+json": {
1529
+ "model_id": "244ced1920694dbaae9bf98065b4f01d",
1530
+ "version_major": 2,
1531
+ "version_minor": 0
1532
+ },
1533
+ "text/plain": [
1534
+ "Processing Files (0 / 0) : | | 0.00B / 0.00B "
1535
+ ]
1536
+ },
1537
+ "metadata": {},
1538
+ "output_type": "display_data"
1539
+ },
1540
+ {
1541
+ "data": {
1542
+ "application/vnd.jupyter.widget-view+json": {
1543
+ "model_id": "e3769ae107554c9ba38c1e491b15bf4e",
1544
+ "version_major": 2,
1545
+ "version_minor": 0
1546
+ },
1547
+ "text/plain": [
1548
+ "New Data Upload : | | 0.00B / 0.00B "
1549
+ ]
1550
+ },
1551
+ "metadata": {},
1552
+ "output_type": "display_data"
1553
+ },
1554
+ {
1555
+ "data": {
1556
+ "application/vnd.jupyter.widget-view+json": {
1557
+ "model_id": "6d5b8bff73474faeb1d1b438fb4e8cec",
1558
+ "version_major": 2,
1559
+ "version_minor": 0
1560
+ },
1561
+ "text/plain": [
1562
+ " ...270m-it/training_args.bin: 100%|##########| 7.57kB / 7.57kB "
1563
+ ]
1564
+ },
1565
+ "metadata": {},
1566
+ "output_type": "display_data"
1567
+ },
1568
+ {
1569
+ "data": {
1570
+ "application/vnd.jupyter.widget-view+json": {
1571
+ "model_id": "9f952f8eb63b42e4b38711737da5461e",
1572
+ "version_major": 2,
1573
+ "version_minor": 0
1574
+ },
1575
+ "text/plain": [
1576
+ " ...a-270m-it/tokenizer.model: 100%|##########| 4.69MB / 4.69MB "
1577
+ ]
1578
+ },
1579
+ "metadata": {},
1580
+ "output_type": "display_data"
1581
+ },
1582
+ {
1583
+ "data": {
1584
+ "application/vnd.jupyter.widget-view+json": {
1585
+ "model_id": "bd12780895064467b5be14e2ec3df114",
1586
+ "version_major": 2,
1587
+ "version_minor": 0
1588
+ },
1589
+ "text/plain": [
1590
+ " ...ma-270m-it/tokenizer.json: 100%|##########| 33.4MB / 33.4MB "
1591
+ ]
1592
+ },
1593
+ "metadata": {},
1594
+ "output_type": "display_data"
1595
+ },
1596
+ {
1597
+ "data": {
1598
+ "application/vnd.jupyter.widget-view+json": {
1599
+ "model_id": "d1261c1083a74dca877e6eece6395d73",
1600
+ "version_major": 2,
1601
+ "version_minor": 0
1602
+ },
1603
+ "text/plain": [
1604
+ " ...270m-it/model.safetensors: 4%|3 | 41.9MB / 1.07GB "
1605
+ ]
1606
+ },
1607
+ "metadata": {},
1608
+ "output_type": "display_data"
1609
+ },
1610
+ {
1611
+ "name": "stderr",
1612
+ "output_type": "stream",
1613
+ "text": [
1614
+ "No files have been modified since last commit. Skipping to prevent empty commit.\n",
1615
+ "WARNING:huggingface_hub.hf_api:No files have been modified since last commit. Skipping to prevent empty commit.\n"
1616
+ ]
1617
+ },
1618
+ {
1619
+ "data": {
1620
+ "application/vnd.jupyter.widget-view+json": {
1621
+ "model_id": "999744cacd6a4fb08a1d4977ce2f06fd",
1622
+ "version_major": 2,
1623
+ "version_minor": 0
1624
+ },
1625
+ "text/plain": [
1626
+ "Processing Files (0 / 0) : | | 0.00B / 0.00B "
1627
+ ]
1628
+ },
1629
+ "metadata": {},
1630
+ "output_type": "display_data"
1631
+ },
1632
+ {
1633
+ "data": {
1634
+ "application/vnd.jupyter.widget-view+json": {
1635
+ "model_id": "faa5e0fb4ee244689c0f9eef9902acf7",
1636
+ "version_major": 2,
1637
+ "version_minor": 0
1638
+ },
1639
+ "text/plain": [
1640
+ "New Data Upload : | | 0.00B / 0.00B "
1641
+ ]
1642
+ },
1643
+ "metadata": {},
1644
+ "output_type": "display_data"
1645
+ },
1646
+ {
1647
+ "data": {
1648
+ "application/vnd.jupyter.widget-view+json": {
1649
+ "model_id": "6403bed2cd984ba18f74f416748c64e4",
1650
+ "version_major": 2,
1651
+ "version_minor": 0
1652
+ },
1653
+ "text/plain": [
1654
+ " ...270m-it/training_args.bin: 100%|##########| 7.57kB / 7.57kB "
1655
+ ]
1656
+ },
1657
+ "metadata": {},
1658
+ "output_type": "display_data"
1659
+ },
1660
+ {
1661
+ "data": {
1662
+ "application/vnd.jupyter.widget-view+json": {
1663
+ "model_id": "38be017369524e2eb22050e7a0a18ec5",
1664
+ "version_major": 2,
1665
+ "version_minor": 0
1666
+ },
1667
+ "text/plain": [
1668
+ " ...a-270m-it/tokenizer.model: 100%|##########| 4.69MB / 4.69MB "
1669
+ ]
1670
+ },
1671
+ "metadata": {},
1672
+ "output_type": "display_data"
1673
+ },
1674
+ {
1675
+ "data": {
1676
+ "application/vnd.jupyter.widget-view+json": {
1677
+ "model_id": "b0720a4a2df948308011d4d87a288426",
1678
+ "version_major": 2,
1679
+ "version_minor": 0
1680
+ },
1681
+ "text/plain": [
1682
+ " ...270m-it/model.safetensors: 3%|3 | 33.5MB / 1.07GB "
1683
+ ]
1684
+ },
1685
+ "metadata": {},
1686
+ "output_type": "display_data"
1687
+ },
1688
+ {
1689
+ "data": {
1690
+ "application/vnd.jupyter.widget-view+json": {
1691
+ "model_id": "889ca2520f4d446daf2e6ed16ce11d2e",
1692
+ "version_major": 2,
1693
+ "version_minor": 0
1694
+ },
1695
+ "text/plain": [
1696
+ " ...ma-270m-it/tokenizer.json: 100%|##########| 33.4MB / 33.4MB "
1697
+ ]
1698
+ },
1699
+ "metadata": {},
1700
+ "output_type": "display_data"
1701
+ },
1702
+ {
1703
+ "name": "stderr",
1704
+ "output_type": "stream",
1705
+ "text": [
1706
+ "No files have been modified since last commit. Skipping to prevent empty commit.\n",
1707
+ "WARNING:huggingface_hub.hf_api:No files have been modified since last commit. Skipping to prevent empty commit.\n"
1708
+ ]
1709
+ },
1710
+ {
1711
+ "data": {
1712
+ "application/vnd.google.colaboratory.intrinsic+json": {
1713
+ "type": "string"
1714
+ },
1715
+ "text/plain": [
1716
+ "CommitInfo(commit_url='https://huggingface.co/sergiopaniego/browsergym-grpo-functiongemma-270m-it/commit/a17de133c28ca7fddfcb2694c32f2791de5ddbe6', commit_message='End of training', commit_description='', oid='a17de133c28ca7fddfcb2694c32f2791de5ddbe6', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sergiopaniego/browsergym-grpo-functiongemma-270m-it', endpoint='https://huggingface.co', repo_type='model', repo_id='sergiopaniego/browsergym-grpo-functiongemma-270m-it'), pr_revision=None, pr_num=None)"
1717
+ ]
1718
+ },
1719
+ "execution_count": 12,
1720
+ "metadata": {},
1721
+ "output_type": "execute_result"
1722
+ }
1723
+ ],
1724
+ "source": [
1725
+ "trainer.save_model(output_dir)\n",
1726
+ "trainer.push_to_hub()"
1727
+ ]
1728
+ },
1729
+ {
1730
+ "cell_type": "markdown",
1731
+ "metadata": {
1732
+ "id": "talmc8b7nPXJ"
1733
+ },
1734
+ "source": [
1735
+ "## Load the Fine-Tuned Model and Run Inference\n",
1736
+ "\n",
1737
+ "The fine-tuned model is loaded to perform inference and evaluate its behavior on the target task. \n",
1738
+ "In this case, the model is tested within the BrowserGym environment using OpenEnv, focusing on the *click* task from the MiniWoB++ benchmark, which is included among the available BrowserGym tasks."
1739
+ ]
1740
+ },
1741
+ {
1742
+ "cell_type": "code",
1743
+ "execution_count": null,
1744
+ "metadata": {
1745
+ "colab": {
1746
+ "referenced_widgets": [
1747
+ "c3879b716f37442a87d51b8414fe8c48"
1748
+ ]
1749
+ },
1750
+ "id": "iIDiaGVlBP5-",
1751
+ "outputId": "4dc0e365-e89f-40ba-b391-74c7efdc932d"
1752
+ },
1753
+ "outputs": [
1754
+ {
1755
+ "data": {
1756
+ "application/vnd.jupyter.widget-view+json": {
1757
+ "model_id": "c3879b716f37442a87d51b8414fe8c48",
1758
+ "version_major": 2,
1759
+ "version_minor": 0
1760
+ },
1761
+ "text/plain": [
1762
+ "model.safetensors: 0%| | 0.00/1.07G [00:00<?, ?B/s]"
1763
+ ]
1764
+ },
1765
+ "metadata": {},
1766
+ "output_type": "display_data"
1767
+ }
1768
+ ],
1769
+ "source": [
1770
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
1771
+ "\n",
1772
+ "model_name = \"sergiopaniego/browsergym-grpo-functiongemma-270m-it\" # Replace with your HF username or organization\n",
1773
+ "\n",
1774
+ "fine_tuned_model = AutoModelForCausalLM.from_pretrained(model_name, dtype=\"float32\", device_map=\"auto\")\n",
1775
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)"
1776
+ ]
1777
+ },
1778
+ {
1779
+ "cell_type": "markdown",
1780
+ "metadata": {
1781
+ "id": "lyT-vudO5ekj"
1782
+ },
1783
+ "source": [
1784
+ "With the fine-tuned model loaded, testing can be conducted on the BrowserGym environment.\n",
1785
+ "To streamline evaluation, a reusable function is defined that executes multiple rounds of the task.\n",
1786
+ "This function follows the same interaction logic as used during training, generating model actions from observations, executing them in the environment, and printing the results step by step."
1787
+ ]
1788
+ },
1789
+ {
1790
+ "cell_type": "code",
1791
+ "execution_count": null,
1792
+ "metadata": {
1793
+ "id": "doAEIf5IBP5-"
1794
+ },
1795
+ "outputs": [],
1796
+ "source": [
1797
+ "def test_click_in_browsergym(env, model, tokenizer):\n",
1798
+ " result = env.reset()\n",
1799
+ " observation = result.observation\n",
1800
+ "\n",
1801
+ " for step_num in range(max_steps):\n",
1802
+ " if result.done:\n",
1803
+ " break\n",
1804
+ "\n",
1805
+ " # Create prompt from observation (text-only using accessibility tree)\n",
1806
+ " goal = observation.goal or dataset_prompt\n",
1807
+ " axtree = observation.axtree_txt or \"\"\n",
1808
+ " error = observation.error if observation.last_action_error else \"\"\n",
1809
+ "\n",
1810
+ " user_prompt = make_user_prompt(goal, step_num, axtree, error)\n",
1811
+ " messages = [\n",
1812
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
1813
+ " {\"role\": \"user\", \"content\": user_prompt},\n",
1814
+ " ]\n",
1815
+ " prompt_text = tokenizer.apply_chat_template(\n",
1816
+ " messages,\n",
1817
+ " add_generation_prompt=True,\n",
1818
+ " tokenize=False,\n",
1819
+ " )\n",
1820
+ "\n",
1821
+ " # Generate action\n",
1822
+ " prompt_text = tokenizer.apply_chat_template(\n",
1823
+ " messages,\n",
1824
+ " add_generation_prompt=True,\n",
1825
+ " tokenize=False,\n",
1826
+ " enable_thinking=False,\n",
1827
+ " )\n",
1828
+ "\n",
1829
+ " model_inputs = tokenizer([prompt_text], return_tensors=\"pt\").to(model.device)\n",
1830
+ "\n",
1831
+ " generated_ids = model.generate(\n",
1832
+ " **model_inputs,\n",
1833
+ " max_new_tokens=512\n",
1834
+ " )\n",
1835
+ " output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n",
1836
+ "\n",
1837
+ " # Decode and extract model response\n",
1838
+ " generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n",
1839
+ "\n",
1840
+ " action_str = parse_action(generated_text)\n",
1841
+ " print(f\"Step {step_num + 1}: {action_str}\")\n",
1842
+ "\n",
1843
+ " # Take action in environment\n",
1844
+ " result = env.step(BrowserGymAction(action_str=action_str))\n",
1845
+ " observation = result.observation"
1846
+ ]
1847
+ },
1848
+ {
1849
+ "cell_type": "markdown",
1850
+ "metadata": {
1851
+ "id": "9QvGD8f8CQx1"
1852
+ },
1853
+ "source": [
1854
+ "The `test_click_in_browsergym` function is called to run a full evaluation of the fine-tuned model on the BrowserGym *click* task. \n",
1855
+ "\n",
1856
+ "The environment client is safely closed after testing using a `try/finally` block, ensuring that all resources are released even if an error occurs during execution."
1857
+ ]
1858
+ },
1859
+ {
1860
+ "cell_type": "code",
1861
+ "execution_count": null,
1862
+ "metadata": {
1863
+ "id": "Z77wlVb6BP5-",
1864
+ "outputId": "ed4ad094-1529-4cc7-8274-2782784efe2d"
1865
+ },
1866
+ "outputs": [
1867
+ {
1868
+ "name": "stdout",
1869
+ "output_type": "stream",
1870
+ "text": [
1871
+ "Step 1: click('13')\n"
1872
+ ]
1873
+ }
1874
+ ],
1875
+ "source": [
1876
+ "try:\n",
1877
+ " test_click_in_browsergym(client, fine_tuned_model, tokenizer)\n",
1878
+ "finally:\n",
1879
+ " client.close()"
1880
+ ]
1881
+ },
1882
+ {
1883
+ "cell_type": "markdown",
1884
+ "metadata": {
1885
+ "id": "wHydP-ZVCcYK"
1886
+ },
1887
+ "source": [
1888
+ "## Summary and Next Steps\n",
1889
+ "\n",
1890
+ "This tutorial demonstrated how to fine-tune a FunctionGemma model using TRL, GRPO, and the BrowserGym environment from OpenEnv. Check out the following docs next:\n",
1891
+ "\n",
1892
+ "- Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).\n",
1893
+ "- Learn how to [fine-tune Gemma for vision tasks using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora).\n",
1894
+ "- Learn how to [full model fine-tune using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune).\n",
1895
+ "- Learn how to [fine-tune Gemma using Hugging Face Transformers with QLoRA](https://ai.google.dev/gemma/docs/core/huggingface_text_finetune_qlora). \n",
1896
+ "- Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).\n",
1897
+ "- Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
1898
+ "- Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)."
1899
+ ]
1900
+ }
1901
+ ],
1902
+ "metadata": {
1903
+ "accelerator": "GPU",
1904
+ "colab": {
1905
+ "gpuType": "A100",
1906
+ "provenance": []
1907
+ },
1908
+ "language_info": {
1909
+ "name": "python"
1910
+ }
1911
+ },
1912
+ "nbformat": 4,
1913
+ "nbformat_minor": 0
1914
+ }
ICL/RL/trl_source/examples/notebooks/grpo_ministral3_vl.ipynb ADDED
@@ -0,0 +1,740 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "-J8iGzLf4rUJ"
7
+ },
8
+ "source": [
9
+ "# GRPO Ministral-3 with QLoRA using TRL\n",
10
+ "\n",
11
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_ministral3_vl.ipynb)\n",
12
+ "\n",
13
+ "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n",
14
+ "\n",
15
+ "\n",
16
+ "With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can fine-tune cutting edge vision language models. It comes with support for quantized parameter efficient fine-tuning technique **QLoRA**, so we can use free Colab (T4 GPU) to fine-tune models like [Ministral-3](https://huggingface.co/collections/mistralai/ministral-3).\n",
17
+ "\n",
18
+ "\n",
19
+ "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n",
20
+ "- [Official TRL Examples (notebooks and scripts)](https://huggingface.co/docs/trl/example_overview) \n",
21
+ "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "markdown",
26
+ "metadata": {
27
+ "id": "NvrzGRnu48Vz"
28
+ },
29
+ "source": [
30
+ "## Install dependencies\n",
31
+ "\n",
32
+ "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training."
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "metadata": {
39
+ "id": "Dbvb3UmQ99p9",
40
+ "outputId": "3ad47e9a-017e-4066-8fe8-77a59586fff3"
41
+ },
42
+ "outputs": [],
43
+ "source": [
44
+ "!pip install -Uq \"trl[peft]\" bitsandbytes trackio math_verify git+https://github.com/huggingface/transformers mistral-common"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "markdown",
49
+ "metadata": {
50
+ "id": "gpzI6omi7728"
51
+ },
52
+ "source": [
53
+ "### Log in to Hugging Face\n",
54
+ "\n",
55
+ "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)."
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {
62
+ "colab": {
63
+ "referenced_widgets": [
64
+ "2ac44d3c070845af86d9b2e3ce8b949f"
65
+ ]
66
+ },
67
+ "id": "h5Ubc70Z99p-",
68
+ "outputId": "633485d3-c79b-4702-ac01-f5a7be5cadfb"
69
+ },
70
+ "outputs": [],
71
+ "source": [
72
+ "from huggingface_hub import notebook_login\n",
73
+ "\n",
74
+ "notebook_login()"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "markdown",
79
+ "metadata": {
80
+ "id": "V_Zylc4t79-n"
81
+ },
82
+ "source": [
83
+ "## Load dataset\n",
84
+ "\n",
85
+ "\n",
86
+ "We'll load the [**lmms-lab/multimodal-open-r1-8k-verified**](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset from the Hugging Face Hub using the `datasets` library.\n",
87
+ "\n",
88
+ "This dataset contains maths problems with the image representing the problem, along with the solution in thinking format specially tailored for VLMs. By training our model with this dataset, it'll improve its maths and thinking reasoning.\n"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {
95
+ "colab": {
96
+ "referenced_widgets": [
97
+ "3538a24e7f63433d91144b0ef765d8f0",
98
+ "23c73818302c4c879d7eca629b4d734d",
99
+ "663a6d37e74c4663a0d5c31aa14b47d6"
100
+ ]
101
+ },
102
+ "id": "OsyilesY99p-",
103
+ "outputId": "4cca7fa0-5f49-4c40-e36a-3a87d2496177"
104
+ },
105
+ "outputs": [],
106
+ "source": [
107
+ "from datasets import load_dataset\n",
108
+ "\n",
109
+ "dataset_id = 'lmms-lab/multimodal-open-r1-8k-verified'\n",
110
+ "train_dataset = load_dataset(dataset_id, split='train[:5%]')"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "markdown",
115
+ "metadata": {
116
+ "id": "gVV7RoRN8zk5"
117
+ },
118
+ "source": [
119
+ "In addition to the `problem` and `image` columns, we also include a custom system prompt to tell the model how we'd like the generation.\n",
120
+ "\n",
121
+ "The system prompt is extracted from DeepSeek R1. Refer to [this previous recipe](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) for more details.\n",
122
+ "\n",
123
+ "We convert the dataset samples into conversation samples, including the system prompt and one image and problem description per sample, since this is how the GRPO trainer expects them.\n",
124
+ "\n",
125
+ "We also set `padding_side=\"left\"` to ensure that generated completions during training are concatenated directly after the prompt, which is essential for GRPO to correctly compare token-level probabilities between preferred and rejected responses.\n",
126
+ "\n",
127
+ "> **Note:**\n",
128
+ "> In older GPUs (including those available on Colab), **FP8 support** is limited, so we use the BF16 version of the model.\n",
129
+ "> In that case, you can select the official checkpoint or the one from Unsloth.\n",
130
+ "> If you have access to GPUs with **FP8 support**, you can switch to that version instead."
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "metadata": {
137
+ "colab": {
138
+ "referenced_widgets": [
139
+ "83dfeaab2bd04b06899d09b6b35bacd1",
140
+ "8588996c1d2d444193e9cf53c1a73b8e",
141
+ "138a997da09f40ada32171e51b51b708",
142
+ "06ef4d5f41de4436ad4731cbf2f8471f"
143
+ ]
144
+ },
145
+ "id": "WlK7KYKT99p-",
146
+ "outputId": "db72808f-21cf-4022-ed1a-b78ebb3ee47e"
147
+ },
148
+ "outputs": [],
149
+ "source": [
150
+ "from transformers import AutoProcessor\n",
151
+ "\n",
152
+ "#model_name = \"mistralai/Ministral-3-3B-Instruct-2512\"\n",
153
+ "model_name = \"mistralai/Ministral-3-3B-Instruct-2512-BF16\" # \"unsloth/Ministral-3-3B-Instruct-2512\"\n",
154
+ "\n",
155
+ "processor = AutoProcessor.from_pretrained(model_name, padding_side=\"left\")\n",
156
+ "\n",
157
+ "SYSTEM_PROMPT = (\n",
158
+ " \"You are a helpful AI Assistant that provides well-reasoned and detailed responses. \"\n",
159
+ " \"You first think about the reasoning process as an internal monologue and then provide the user with the answer. \"\n",
160
+ " \"Respond in the following format: <think>\\n...\\n</think>\\n<answer>\\n...\\n</answer>\"\n",
161
+ ")\n",
162
+ "\n",
163
+ "\n",
164
+ "def make_conversation(example):\n",
165
+ " conversation = [\n",
166
+ " {\n",
167
+ " \"role\": \"system\",\n",
168
+ " \"content\": [{\"type\": \"text\", \"text\": SYSTEM_PROMPT}],\n",
169
+ " },\n",
170
+ " {\n",
171
+ " \"role\": \"user\",\n",
172
+ " \"content\": [\n",
173
+ " {\"type\": \"image\", \"image\": example[\"image\"]},\n",
174
+ " {\"type\": \"text\", \"text\": example[\"problem\"]},\n",
175
+ " ],\n",
176
+ " },\n",
177
+ " ]\n",
178
+ " return {\n",
179
+ " \"prompt\": conversation,\n",
180
+ " \"image\": example[\"image\"],\n",
181
+ " }\n",
182
+ "\n",
183
+ "train_dataset = train_dataset.map(make_conversation)"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "markdown",
188
+ "metadata": {
189
+ "id": "5txAuMAa8ock"
190
+ },
191
+ "source": [
192
+ "Let's review one example to understand the internal structure:"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {
199
+ "id": "sjxG7duU99p_"
200
+ },
201
+ "outputs": [],
202
+ "source": [
203
+ "train_dataset[0]"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": null,
209
+ "metadata": {
210
+ "id": "ZooycTF099p_"
211
+ },
212
+ "outputs": [],
213
+ "source": [
214
+ "train_dataset = train_dataset.remove_columns(['problem', 'original_question', 'original_answer'])"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": null,
220
+ "metadata": {
221
+ "id": "2LcjFKgD99p_"
222
+ },
223
+ "outputs": [],
224
+ "source": [
225
+ "train_dataset[0]"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "markdown",
230
+ "metadata": {
231
+ "id": "YY3uMp909Eqy"
232
+ },
233
+ "source": [
234
+ "## Load model and configure LoRA/QLoRA\n",
235
+ "\n",
236
+ "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration."
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": null,
242
+ "metadata": {
243
+ "id": "RcQn7mGs99p_"
244
+ },
245
+ "outputs": [],
246
+ "source": [
247
+ "from transformers import Mistral3ForConditionalGeneration, FineGrainedFP8Config, BitsAndBytesConfig\n",
248
+ "import torch\n",
249
+ "\n",
250
+ "FP8 = False\n",
251
+ "\n",
252
+ "if FP8:\n",
253
+ " model_name = \"mistralai/Ministral-3-3B-Instruct-2512\"\n",
254
+ " quantization_config = FineGrainedFP8Config(dequantize=False)\n",
255
+ "else:\n",
256
+ " model_name = \"mistralai/Ministral-3-3B-Instruct-2512-BF16\" # \"unsloth/Ministral-3-3B-Instruct-2512\"\n",
257
+ " quantization_config = BitsAndBytesConfig(\n",
258
+ " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n",
259
+ " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n",
260
+ " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n",
261
+ " bnb_4bit_quant_type=\"nf4\", # Type of quantization. \"nf4\" is recommended for recent LLMs\n",
262
+ " )\n",
263
+ "\n",
264
+ "model = Mistral3ForConditionalGeneration.from_pretrained(\n",
265
+ " model_name,\n",
266
+ " dtype=\"float32\",\n",
267
+ " device_map=\"auto\",\n",
268
+ " quantization_config=quantization_config,\n",
269
+ ")"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "markdown",
274
+ "metadata": {
275
+ "id": "WZGf-GF09Gsc"
276
+ },
277
+ "source": [
278
+ "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning."
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": null,
284
+ "metadata": {
285
+ "id": "LqCEI4hf99p_"
286
+ },
287
+ "outputs": [],
288
+ "source": [
289
+ "from peft import LoraConfig\n",
290
+ "\n",
291
+ "# You may need to update `target_modules` depending on the architecture of your chosen model.\n",
292
+ "# For example, different VLMs might have different attention/projection layer names.\n",
293
+ "peft_config = LoraConfig(\n",
294
+ " r=8,\n",
295
+ " lora_alpha=32,\n",
296
+ " lora_dropout=0.1,\n",
297
+ " target_modules=[\"q_proj\", \"v_proj\"],\n",
298
+ ")"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "markdown",
303
+ "metadata": {
304
+ "id": "mDq4V6dN9MGk"
305
+ },
306
+ "source": [
307
+ "## Train model\n",
308
+ "\n",
309
+ "We'll configure **GRPO** using `GRPOConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL GRPOConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.GRPOConfig).\n",
310
+ "\n",
311
+ "First, we need to define the rewards functions that the training algorithm will use to improve the model. In this case, we'll include two reward functions.\n",
312
+ "We'll use a format reward that will reward the model when the output includes `<think>` and `<answer>` tags and additionally a length-based reward to discourage overthinking. Both functions have been extracted from [here](https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py)."
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "execution_count": null,
318
+ "metadata": {
319
+ "id": "jhgqx8kO99p_"
320
+ },
321
+ "outputs": [],
322
+ "source": [
323
+ "import re\n",
324
+ "\n",
325
+ "def format_reward(completions, **kwargs):\n",
326
+ " \"\"\"Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags.\"\"\"\n",
327
+ " pattern = r\"<think>.*?</think>.*?<answer>.*?</answer>\"\n",
328
+ "\n",
329
+ " matches = []\n",
330
+ " for item in completions:\n",
331
+ " if isinstance(item, list):\n",
332
+ " text = item[0]['content']\n",
333
+ " else:\n",
334
+ " text = item\n",
335
+ " match = re.match(pattern, text, re.DOTALL | re.MULTILINE)\n",
336
+ " matches.append(match)\n",
337
+ "\n",
338
+ " return [1.0 if match else 0.0 for match in matches]"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": null,
344
+ "metadata": {
345
+ "id": "sVmzQ_wL99p_"
346
+ },
347
+ "outputs": [],
348
+ "source": [
349
+ "from math_verify import LatexExtractionConfig, parse, verify\n",
350
+ "from latex2sympy2_extended import NormalizationConfig\n",
351
+ "\n",
352
+ "\n",
353
+ "def len_reward(completions, solution, **kwargs) -> float:\n",
354
+ " \"\"\"Compute length-based rewards to discourage overthinking and promote token efficiency.\n",
355
+ "\n",
356
+ " Taken from the Kimi 1.5 tech report: https://huggingface.co/papers/2501.12599\n",
357
+ "\n",
358
+ " Args:\n",
359
+ " completions: List of model completions\n",
360
+ " solution: List of ground truth solutions\n",
361
+ "\n",
362
+ " Returns:\n",
363
+ " List of rewards where:\n",
364
+ " - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)\n",
365
+ " - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))\n",
366
+ " \"\"\"\n",
367
+ " contents = []\n",
368
+ " for item in completions:\n",
369
+ " if isinstance(item, list):\n",
370
+ " text = item[0]['content']\n",
371
+ " else:\n",
372
+ " text = item\n",
373
+ " contents.append(text)\n",
374
+ "\n",
375
+ " # First check correctness of answers\n",
376
+ " correctness = []\n",
377
+ " for content, sol in zip(contents, solution):\n",
378
+ " gold_parsed = parse(\n",
379
+ " sol,\n",
380
+ " extraction_mode=\"first_match\",\n",
381
+ " extraction_config=[LatexExtractionConfig()],\n",
382
+ " )\n",
383
+ " if len(gold_parsed) == 0:\n",
384
+ " # Skip unparsable examples\n",
385
+ " correctness.append(True) # Treat as correct to avoid penalizing\n",
386
+ " print(\"Failed to parse gold solution: \", sol)\n",
387
+ " continue\n",
388
+ "\n",
389
+ " answer_parsed = parse(\n",
390
+ " content,\n",
391
+ " extraction_config=[\n",
392
+ " LatexExtractionConfig(\n",
393
+ " normalization_config=NormalizationConfig(\n",
394
+ " nits=False,\n",
395
+ " malformed_operators=False,\n",
396
+ " basic_latex=True,\n",
397
+ " equations=True,\n",
398
+ " boxed=True,\n",
399
+ " units=True,\n",
400
+ " ),\n",
401
+ " boxed_match_priority=0,\n",
402
+ " try_extract_without_anchor=False,\n",
403
+ " )\n",
404
+ " ],\n",
405
+ " extraction_mode=\"first_match\",\n",
406
+ " )\n",
407
+ " correctness.append(verify(answer_parsed, gold_parsed))\n",
408
+ "\n",
409
+ " # Calculate lengths\n",
410
+ " lengths = [len(content) for content in contents]\n",
411
+ " min_len = min(lengths)\n",
412
+ " max_len = max(lengths)\n",
413
+ "\n",
414
+ " # If all responses have the same length, return zero rewards\n",
415
+ " if max_len == min_len:\n",
416
+ " return [0.0] * len(completions)\n",
417
+ "\n",
418
+ " rewards = []\n",
419
+ " for length, is_correct in zip(lengths, correctness):\n",
420
+ " lambda_val = 0.5 - (length - min_len) / (max_len - min_len)\n",
421
+ "\n",
422
+ " if is_correct:\n",
423
+ " reward = lambda_val\n",
424
+ " else:\n",
425
+ " reward = min(0, lambda_val)\n",
426
+ "\n",
427
+ " rewards.append(float(reward))\n",
428
+ "\n",
429
+ " return rewards"
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "markdown",
434
+ "metadata": {
435
+ "id": "9xBL7Rni9LZb"
436
+ },
437
+ "source": [
438
+ "After defining the reward function(s), we can define the `GRPOConfig`."
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": null,
444
+ "metadata": {
445
+ "id": "pcv6KXUD99qA"
446
+ },
447
+ "outputs": [],
448
+ "source": [
449
+ "from trl import GRPOConfig\n",
450
+ "\n",
451
+ "output_dir = \"Ministral-3-3B-Instruct-trl-grpo\"\n",
452
+ "\n",
453
+ "# Configure training arguments using GRPOConfig\n",
454
+ "training_args = GRPOConfig(\n",
455
+ " learning_rate=2e-5,\n",
456
+ " #num_train_epochs=1,\n",
457
+ " max_steps=100, # Number of dataset passes. For full trainings, use `num_train_epochs` instead\n",
458
+ "\n",
459
+ " # Parameters that control the data preprocessing\n",
460
+ " per_device_train_batch_size=2,\n",
461
+ " max_completion_length=1024, # default: 256 # Max completion length produced during training\n",
462
+ " num_generations=2, # 2, # default: 8 # Number of generations produced during training for comparison\n",
463
+ "\n",
464
+ " fp16=False,\n",
465
+ " bf16=False,\n",
466
+ "\n",
467
+ " # Parameters related to reporting and saving\n",
468
+ " output_dir=output_dir, # Where to save model checkpoints and logs\n",
469
+ " logging_steps=1, # Log training metrics every N steps\n",
470
+ " report_to=\"trackio\", # Experiment tracking tool\n",
471
+ " trackio_space_id = output_dir,\n",
472
+ "\n",
473
+ " # Hub integration\n",
474
+ " push_to_hub=True,\n",
475
+ " log_completions=True,\n",
476
+ ")"
477
+ ]
478
+ },
479
+ {
480
+ "cell_type": "markdown",
481
+ "metadata": {
482
+ "id": "O0q3myQg927v"
483
+ },
484
+ "source": [
485
+ "Configure the GRPO Trainer. We pass the previously configured `training_args`. We don't use eval dataset to maintain memory usage low but you can configure it."
486
+ ]
487
+ },
488
+ {
489
+ "cell_type": "code",
490
+ "execution_count": null,
491
+ "metadata": {
492
+ "id": "-zd7s5Cs99qA"
493
+ },
494
+ "outputs": [],
495
+ "source": [
496
+ "from trl import GRPOTrainer\n",
497
+ "\n",
498
+ "trainer = GRPOTrainer(\n",
499
+ " model=model,\n",
500
+ " reward_funcs=[format_reward, len_reward],\n",
501
+ " args=training_args,\n",
502
+ " train_dataset=train_dataset,\n",
503
+ " peft_config=peft_config,\n",
504
+ ")"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "markdown",
509
+ "metadata": {
510
+ "id": "kQC7Q5kg95xq"
511
+ },
512
+ "source": [
513
+ "Show memory stats before training"
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "code",
518
+ "execution_count": null,
519
+ "metadata": {
520
+ "id": "iF7cnD0T99qA"
521
+ },
522
+ "outputs": [],
523
+ "source": [
524
+ "gpu_stats = torch.cuda.get_device_properties(0)\n",
525
+ "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
526
+ "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
527
+ "\n",
528
+ "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
529
+ "print(f\"{start_gpu_memory} GB of memory reserved.\")"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "markdown",
534
+ "metadata": {
535
+ "id": "YazYtLAe97Dc"
536
+ },
537
+ "source": [
538
+ "And train!"
539
+ ]
540
+ },
541
+ {
542
+ "cell_type": "code",
543
+ "execution_count": null,
544
+ "metadata": {
545
+ "id": "Ynhxdv3a99qA"
546
+ },
547
+ "outputs": [],
548
+ "source": [
549
+ "trainer_stats = trainer.train()"
550
+ ]
551
+ },
552
+ {
553
+ "cell_type": "markdown",
554
+ "metadata": {
555
+ "id": "SmcYN5yW99IP"
556
+ },
557
+ "source": [
558
+ "Show memory stats after training"
559
+ ]
560
+ },
561
+ {
562
+ "cell_type": "code",
563
+ "execution_count": null,
564
+ "metadata": {
565
+ "id": "mi-exH7699qA"
566
+ },
567
+ "outputs": [],
568
+ "source": [
569
+ "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
570
+ "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
571
+ "used_percentage = round(used_memory / max_memory * 100, 3)\n",
572
+ "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
573
+ "\n",
574
+ "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
575
+ "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
576
+ "print(f\"Peak reserved memory = {used_memory} GB.\")\n",
577
+ "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
578
+ "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
579
+ "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
580
+ ]
581
+ },
582
+ {
583
+ "cell_type": "markdown",
584
+ "metadata": {
585
+ "id": "saarW87Y9_-R"
586
+ },
587
+ "source": [
588
+ "## Saving fine tuned model\n",
589
+ "\n",
590
+ "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account."
591
+ ]
592
+ },
593
+ {
594
+ "cell_type": "code",
595
+ "execution_count": null,
596
+ "metadata": {
597
+ "id": "m3mlwQl699qA"
598
+ },
599
+ "outputs": [],
600
+ "source": [
601
+ "trainer.save_model(output_dir)\n",
602
+ "trainer.push_to_hub(dataset_name=dataset_id)"
603
+ ]
604
+ },
605
+ {
606
+ "cell_type": "markdown",
607
+ "metadata": {
608
+ "id": "nfqvO0qw-OvS"
609
+ },
610
+ "source": [
611
+ "## Load the fine-tuned model and run inference\n",
612
+ "\n",
613
+ "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation."
614
+ ]
615
+ },
616
+ {
617
+ "cell_type": "code",
618
+ "execution_count": null,
619
+ "metadata": {
620
+ "id": "B7usNBq699qA"
621
+ },
622
+ "outputs": [],
623
+ "source": [
624
+ "from transformers import Mistral3ForConditionalGeneration, MistralCommonBackend\n",
625
+ "from peft import PeftModel\n",
626
+ "\n",
627
+ "base_model = model_name\n",
628
+ "adapter_model = f\"{output_dir}\" # Replace with your HF username or organization\n",
629
+ "\n",
630
+ "model = Mistral3ForConditionalGeneration.from_pretrained(base_model, dtype=\"float32\", device_map=\"auto\")\n",
631
+ "model = PeftModel.from_pretrained(model, adapter_model)\n",
632
+ "\n",
633
+ "tokenizer = MistralCommonBackend.from_pretrained(base_model)"
634
+ ]
635
+ },
636
+ {
637
+ "cell_type": "code",
638
+ "execution_count": null,
639
+ "metadata": {
640
+ "id": "XnIOkXfy99qA"
641
+ },
642
+ "outputs": [],
643
+ "source": [
644
+ "train_dataset[0]"
645
+ ]
646
+ },
647
+ {
648
+ "cell_type": "code",
649
+ "execution_count": null,
650
+ "metadata": {
651
+ "id": "0le5gBl_99qA"
652
+ },
653
+ "outputs": [],
654
+ "source": [
655
+ "from datasets import load_dataset\n",
656
+ "import base64\n",
657
+ "from io import BytesIO\n",
658
+ "\n",
659
+ "dataset_id = 'lmms-lab/multimodal-open-r1-8k-verified'\n",
660
+ "train_dataset = load_dataset(dataset_id, split='train[:5%]')\n",
661
+ "\n",
662
+ "problem = train_dataset[0]['problem']\n",
663
+ "image = train_dataset[0]['image']\n",
664
+ "\n",
665
+ "buffer = BytesIO()\n",
666
+ "image.save(buffer, format=\"JPEG\")\n",
667
+ "image_bytes = buffer.getvalue()\n",
668
+ "image_b64 = base64.b64encode(image_bytes).decode(\"utf-8\")\n",
669
+ "\n",
670
+ "messages = [\n",
671
+ " {\n",
672
+ " \"role\": \"system\", \"content\": [\n",
673
+ " {\"type\": \"text\", \"text\": SYSTEM_PROMPT}\n",
674
+ " ]\n",
675
+ " },\n",
676
+ " {\n",
677
+ " \"role\": \"user\",\n",
678
+ " \"content\": [\n",
679
+ " {\n",
680
+ " \"type\": \"image_url\",\n",
681
+ " \"image_url\": {\n",
682
+ " \"url\": f\"data:image/jpeg;base64,{image_b64}\"\n",
683
+ " },\n",
684
+ " },\n",
685
+ " {\"type\": \"text\", \"text\": problem},\n",
686
+ " ],\n",
687
+ " },\n",
688
+ "]"
689
+ ]
690
+ },
691
+ {
692
+ "cell_type": "code",
693
+ "execution_count": null,
694
+ "metadata": {
695
+ "id": "f9PgBCD499qA"
696
+ },
697
+ "outputs": [],
698
+ "source": [
699
+ "messages"
700
+ ]
701
+ },
702
+ {
703
+ "cell_type": "code",
704
+ "execution_count": null,
705
+ "metadata": {
706
+ "id": "ENOGILKk99qA"
707
+ },
708
+ "outputs": [],
709
+ "source": [
710
+ "import torch\n",
711
+ "\n",
712
+ "tokenized = tokenizer.apply_chat_template(messages, return_tensors=\"pt\", return_dict=True)\n",
713
+ "tokenized[\"input_ids\"] = tokenized[\"input_ids\"].to(device=\"cuda\")\n",
714
+ "tokenized[\"pixel_values\"] = tokenized[\"pixel_values\"].to(dtype=torch.bfloat16, device=\"cuda\")\n",
715
+ "image_sizes = [tokenized[\"pixel_values\"].shape[-2:]]\n",
716
+ "\n",
717
+ "output = model.generate(\n",
718
+ " **tokenized,\n",
719
+ " image_sizes=image_sizes,\n",
720
+ " max_new_tokens=512,\n",
721
+ ")[0]\n",
722
+ "\n",
723
+ "decoded_output = tokenizer.decode(output[len(tokenized[\"input_ids\"][0]):])\n",
724
+ "print(decoded_output)"
725
+ ]
726
+ }
727
+ ],
728
+ "metadata": {
729
+ "accelerator": "GPU",
730
+ "colab": {
731
+ "gpuType": "T4",
732
+ "provenance": []
733
+ },
734
+ "language_info": {
735
+ "name": "python"
736
+ }
737
+ },
738
+ "nbformat": 4,
739
+ "nbformat_minor": 0
740
+ }
ICL/RL/trl_source/examples/notebooks/grpo_qwen3_vl.ipynb ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "-J8iGzLf4rUJ"
7
+ },
8
+ "source": [
9
+ "# GRPO Qwen3-VL with QLoRA using TRL\n",
10
+ "\n",
11
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_qwen3_vl.ipynb)\n",
12
+ "\n",
13
+ "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n",
14
+ "\n",
15
+ "\n",
16
+ "With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can fine-tune cutting edge vision language models. It comes with support for quantized parameter efficient fine-tuning technique **QLoRA**, so we can use free Colab (T4 GPU) to fine-tune models like [Qwen3-VL](https://huggingface.co/collections/Qwen/qwen3-vl-68d2a7c1b8a8afce4ebd2dbe).\n",
17
+ "\n",
18
+ "\n",
19
+ "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n",
20
+ "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n",
21
+ "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)\n",
22
+ "- [More Qwen3-VL Fine-tuning Examples (including TRL scripts)](https://github.com/QwenLM/Qwen3-VL/tree/main/qwen-vl-finetune/)"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "markdown",
27
+ "metadata": {
28
+ "id": "NvrzGRnu48Vz"
29
+ },
30
+ "source": [
31
+ "## Install dependencies\n",
32
+ "\n",
33
+ "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training."
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "metadata": {
40
+ "id": "8CfZlUevmkg7"
41
+ },
42
+ "outputs": [],
43
+ "source": [
44
+ "!pip install -Uq \"trl[peft]\" bitsandbytes trackio math_verify"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "markdown",
49
+ "metadata": {
50
+ "id": "gpzI6omi7728"
51
+ },
52
+ "source": [
53
+ "### Log in to Hugging Face\n",
54
+ "\n",
55
+ "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)."
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {
62
+ "id": "4Ncx0wYtnYCW"
63
+ },
64
+ "outputs": [],
65
+ "source": [
66
+ "from huggingface_hub import notebook_login\n",
67
+ "\n",
68
+ "notebook_login()"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "markdown",
73
+ "metadata": {
74
+ "id": "V_Zylc4t79-n"
75
+ },
76
+ "source": [
77
+ "## Load dataset\n",
78
+ "\n",
79
+ "\n",
80
+ "We'll load the [**lmms-lab/multimodal-open-r1-8k-verified**](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset from the Hugging Face Hub using the `datasets` library.\n",
81
+ "\n",
82
+ "This dataset contains maths problems with the image representing the problem, along with the solution in thinking format specially tailored for VLMs. By training our model with this dataset, it'll improve its maths and thinking reasoning.\n"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {
89
+ "id": "TzXogU24F_QR"
90
+ },
91
+ "outputs": [],
92
+ "source": [
93
+ "from datasets import load_dataset\n",
94
+ "\n",
95
+ "dataset_id = 'lmms-lab/multimodal-open-r1-8k-verified'\n",
96
+ "train_dataset = load_dataset(dataset_id, split='train[:5%]')"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "markdown",
101
+ "metadata": {
102
+ "id": "gVV7RoRN8zk5"
103
+ },
104
+ "source": [
105
+ "In addition to the `problem` and `image` columns, we also include a custom system prompt to tell the model how we'd like the generation.\n",
106
+ "\n",
107
+ "The system prompt is extracted from DeepSeek R1. Refer to [this previous recipe](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) for more details.\n",
108
+ "\n",
109
+ "We convert the dataset samples into conversation samples, including the system prompt and one image and problem description per sample, since this is how the GRPO trainer expects them.\n",
110
+ "\n",
111
+ "We also set `padding_side=\"left\"` to ensure that generated completions during training are concatenated directly after the prompt, which is essential for GRPO to correctly compare token-level probabilities between preferred and rejected responses."
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "metadata": {
118
+ "id": "ZT1JfiiTGExB"
119
+ },
120
+ "outputs": [],
121
+ "source": [
122
+ "from transformers import AutoProcessor\n",
123
+ "\n",
124
+ "model_name = \"Qwen/Qwen3-VL-4B-Instruct\" # \"Qwen/Qwen3-VL-8B-Instruct\"\n",
125
+ "processor = AutoProcessor.from_pretrained(model_name, padding_side=\"left\")\n",
126
+ "\n",
127
+ "SYSTEM_PROMPT = (\n",
128
+ " \"You are a helpful AI Assistant that provides well-reasoned and detailed responses. \"\n",
129
+ " \"You first think about the reasoning process as an internal monologue and then provide the user with the answer. \"\n",
130
+ " \"Respond in the following format: <think>\\n...\\n</think>\\n<answer>\\n...\\n</answer>\"\n",
131
+ ")\n",
132
+ "\n",
133
+ "\n",
134
+ "def make_conversation(example):\n",
135
+ " conversation = [\n",
136
+ " {\n",
137
+ " \"role\": \"system\",\n",
138
+ " \"content\": [{\"type\": \"text\", \"text\": SYSTEM_PROMPT}],\n",
139
+ " },\n",
140
+ " {\n",
141
+ " \"role\": \"user\",\n",
142
+ " \"content\": [\n",
143
+ " {\"type\": \"image\", \"image\": example[\"image\"]},\n",
144
+ " {\"type\": \"text\", \"text\": example[\"problem\"]},\n",
145
+ " ],\n",
146
+ " },\n",
147
+ " ]\n",
148
+ " prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)\n",
149
+ " return {\n",
150
+ " \"prompt\": prompt,\n",
151
+ " \"image\": example[\"image\"],\n",
152
+ " }\n",
153
+ "\n",
154
+ "train_dataset = train_dataset.map(make_conversation)"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "markdown",
159
+ "metadata": {
160
+ "id": "5txAuMAa8ock"
161
+ },
162
+ "source": [
163
+ "Let's review one example to understand the internal structure:"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "metadata": {
170
+ "id": "PDXQd5Jk2Bqe"
171
+ },
172
+ "outputs": [],
173
+ "source": [
174
+ "train_dataset[0]"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "metadata": {
181
+ "id": "hzSR_56wxKDA"
182
+ },
183
+ "outputs": [],
184
+ "source": [
185
+ "train_dataset = train_dataset.remove_columns(['problem', 'original_question', 'original_answer'])"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "metadata": {
192
+ "id": "T9rCkeqDODba"
193
+ },
194
+ "outputs": [],
195
+ "source": [
196
+ "train_dataset[0]"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "markdown",
201
+ "metadata": {
202
+ "id": "YY3uMp909Eqy"
203
+ },
204
+ "source": [
205
+ "## Load model and configure LoRA/QLoRA\n",
206
+ "\n",
207
+ "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration."
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": null,
213
+ "metadata": {
214
+ "id": "gt05dgXgm9QR"
215
+ },
216
+ "outputs": [],
217
+ "source": [
218
+ "from transformers import Qwen3VLForConditionalGeneration, BitsAndBytesConfig\n",
219
+ "import torch\n",
220
+ "\n",
221
+ "model = Qwen3VLForConditionalGeneration.from_pretrained(\n",
222
+ " model_name, dtype=\"float32\",\n",
223
+ " device_map=\"auto\",\n",
224
+ " quantization_config=BitsAndBytesConfig(\n",
225
+ " load_in_4bit=True,\n",
226
+ " bnb_4bit_use_double_quant=True,\n",
227
+ " bnb_4bit_quant_type=\"nf4\",\n",
228
+ " bnb_4bit_compute_dtype=torch.float16\n",
229
+ " ),\n",
230
+ ")"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "markdown",
235
+ "metadata": {
236
+ "id": "WZGf-GF09Gsc"
237
+ },
238
+ "source": [
239
+ "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning."
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": null,
245
+ "metadata": {
246
+ "id": "ME1im5gh2LFg"
247
+ },
248
+ "outputs": [],
249
+ "source": [
250
+ "from peft import LoraConfig\n",
251
+ "\n",
252
+ "# You may need to update `target_modules` depending on the architecture of your chosen model.\n",
253
+ "# For example, different VLMs might have different attention/projection layer names.\n",
254
+ "peft_config = LoraConfig(\n",
255
+ " r=8,\n",
256
+ " lora_alpha=32,\n",
257
+ " lora_dropout=0.1,\n",
258
+ " target_modules=[\"q_proj\", \"v_proj\"],\n",
259
+ ")"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "markdown",
264
+ "metadata": {
265
+ "id": "mDq4V6dN9MGk"
266
+ },
267
+ "source": [
268
+ "## Train model\n",
269
+ "\n",
270
+ "We'll configure **GRPO** using `GRPOConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL GRPOConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.GRPOConfig).\n",
271
+ "\n",
272
+ "First, we need to define the rewards functions that the training algorithm will use to improve the model. In this case, we'll include two reward functions.\n",
273
+ "We'll use a format reward that will reward the model when the output includes `<think>` and `<answer>` tags and additionally a length-based reward to discourage overthinking. Both functions have been extracted from [here](https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py)."
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "metadata": {
280
+ "id": "Dqp3TfUwHUxW"
281
+ },
282
+ "outputs": [],
283
+ "source": [
284
+ "import re\n",
285
+ "\n",
286
+ "def format_reward(completions, **kwargs):\n",
287
+ " \"\"\"Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags.\"\"\"\n",
288
+ " pattern = r\"^<think>\\n.*?\\n</think>\\n<answer>\\n.*?\\n</answer>$\"\n",
289
+ " matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]\n",
290
+ " return [1.0 if match else 0.0 for match in matches]"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": null,
296
+ "metadata": {
297
+ "id": "rxNPUp7RBFcz"
298
+ },
299
+ "outputs": [],
300
+ "source": [
301
+ "from math_verify import LatexExtractionConfig, parse, verify\n",
302
+ "from latex2sympy2_extended import NormalizationConfig\n",
303
+ "\n",
304
+ "\n",
305
+ "def len_reward(completions, solution, **kwargs) -> float:\n",
306
+ " \"\"\"Compute length-based rewards to discourage overthinking and promote token efficiency.\n",
307
+ "\n",
308
+ " Taken from the Kimi 1.5 tech report: https://huggingface.co/papers/2501.12599\n",
309
+ "\n",
310
+ " Args:\n",
311
+ " completions: List of model completions\n",
312
+ " solution: List of ground truth solutions\n",
313
+ "\n",
314
+ " Returns:\n",
315
+ " List of rewards where:\n",
316
+ " - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)\n",
317
+ " - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))\n",
318
+ " \"\"\"\n",
319
+ " contents = completions\n",
320
+ "\n",
321
+ " # First check correctness of answers\n",
322
+ " correctness = []\n",
323
+ " for content, sol in zip(contents, solution):\n",
324
+ " gold_parsed = parse(\n",
325
+ " sol,\n",
326
+ " extraction_mode=\"first_match\",\n",
327
+ " extraction_config=[LatexExtractionConfig()],\n",
328
+ " )\n",
329
+ " if len(gold_parsed) == 0:\n",
330
+ " # Skip unparsable examples\n",
331
+ " correctness.append(True) # Treat as correct to avoid penalizing\n",
332
+ " print(\"Failed to parse gold solution: \", sol)\n",
333
+ " continue\n",
334
+ "\n",
335
+ " answer_parsed = parse(\n",
336
+ " content,\n",
337
+ " extraction_config=[\n",
338
+ " LatexExtractionConfig(\n",
339
+ " normalization_config=NormalizationConfig(\n",
340
+ " nits=False,\n",
341
+ " malformed_operators=False,\n",
342
+ " basic_latex=True,\n",
343
+ " equations=True,\n",
344
+ " boxed=True,\n",
345
+ " units=True,\n",
346
+ " ),\n",
347
+ " boxed_match_priority=0,\n",
348
+ " try_extract_without_anchor=False,\n",
349
+ " )\n",
350
+ " ],\n",
351
+ " extraction_mode=\"first_match\",\n",
352
+ " )\n",
353
+ " correctness.append(verify(answer_parsed, gold_parsed))\n",
354
+ "\n",
355
+ " # Calculate lengths\n",
356
+ " lengths = [len(content) for content in contents]\n",
357
+ " min_len = min(lengths)\n",
358
+ " max_len = max(lengths)\n",
359
+ "\n",
360
+ " # If all responses have the same length, return zero rewards\n",
361
+ " if max_len == min_len:\n",
362
+ " return [0.0] * len(completions)\n",
363
+ "\n",
364
+ " rewards = []\n",
365
+ " for length, is_correct in zip(lengths, correctness):\n",
366
+ " lambda_val = 0.5 - (length - min_len) / (max_len - min_len)\n",
367
+ "\n",
368
+ " if is_correct:\n",
369
+ " reward = lambda_val\n",
370
+ " else:\n",
371
+ " reward = min(0, lambda_val)\n",
372
+ "\n",
373
+ " rewards.append(float(reward))\n",
374
+ "\n",
375
+ " return rewards\n"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "markdown",
380
+ "metadata": {
381
+ "id": "9xBL7Rni9LZb"
382
+ },
383
+ "source": [
384
+ "After defining the reward function(s), we can define the `GRPOConfig`."
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "execution_count": null,
390
+ "metadata": {
391
+ "id": "OEmRM0rIHXQ4"
392
+ },
393
+ "outputs": [],
394
+ "source": [
395
+ "from trl import GRPOConfig\n",
396
+ "\n",
397
+ "output_dir = \"Qwen3-VL-4B-Instruct-trl-grpo\"\n",
398
+ "\n",
399
+ "# Configure training arguments using GRPOConfig\n",
400
+ "training_args = GRPOConfig(\n",
401
+ " learning_rate=2e-5,\n",
402
+ " #num_train_epochs=1,\n",
403
+ " max_steps=100, # Number of dataset passes. For full trainings, use `num_train_epochs` instead\n",
404
+ "\n",
405
+ " # Parameters that control the data preprocessing\n",
406
+ " per_device_train_batch_size=2,\n",
407
+ " max_completion_length=1024, # default: 256 # Max completion length produced during training\n",
408
+ " num_generations=2, # 2, # default: 8 # Number of generations produced during training for comparison\n",
409
+ "\n",
410
+ " fp16=True,\n",
411
+ "\n",
412
+ " # Parameters related to reporting and saving\n",
413
+ " output_dir=output_dir, # Where to save model checkpoints and logs\n",
414
+ " logging_steps=1, # Log training metrics every N steps\n",
415
+ " report_to=\"trackio\", # Experiment tracking tool\n",
416
+ "\n",
417
+ " # Hub integration\n",
418
+ " push_to_hub=True,\n",
419
+ " log_completions=True\n",
420
+ ")"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "markdown",
425
+ "metadata": {
426
+ "id": "O0q3myQg927v"
427
+ },
428
+ "source": [
429
+ "Configure the GRPO Trainer. We pass the previously configured `training_args`. We don't use eval dataset to maintain memory usage low but you can configure it."
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "code",
434
+ "execution_count": null,
435
+ "metadata": {
436
+ "colab": {
437
+ "base_uri": "https://localhost:8080/"
438
+ },
439
+ "id": "z5JxkmS9HqD5",
440
+ "outputId": "2b39338e-2194-4829-fc54-5e286566fd28"
441
+ },
442
+ "outputs": [
443
+ {
444
+ "name": "stderr",
445
+ "output_type": "stream",
446
+ "text": [
447
+ "/usr/local/lib/python3.12/dist-packages/peft/mapping_func.py:73: UserWarning: You are trying to modify a model with PEFT for a second time. If you want to reload the model with a different config, make sure to call `.unload()` before.\n",
448
+ " warnings.warn(\n",
449
+ "/usr/local/lib/python3.12/dist-packages/peft/tuners/tuners_utils.py:196: UserWarning: Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!\n",
450
+ " warnings.warn(\n"
451
+ ]
452
+ }
453
+ ],
454
+ "source": [
455
+ "from trl import GRPOTrainer\n",
456
+ "\n",
457
+ "trainer = GRPOTrainer(\n",
458
+ " model=model,\n",
459
+ " reward_funcs=[format_reward, len_reward],\n",
460
+ " args=training_args,\n",
461
+ " train_dataset=train_dataset,\n",
462
+ " peft_config=peft_config,\n",
463
+ ")"
464
+ ]
465
+ },
466
+ {
467
+ "cell_type": "markdown",
468
+ "metadata": {
469
+ "id": "kQC7Q5kg95xq"
470
+ },
471
+ "source": [
472
+ "Show memory stats before training"
473
+ ]
474
+ },
475
+ {
476
+ "cell_type": "code",
477
+ "execution_count": null,
478
+ "metadata": {
479
+ "id": "naG_7qlYyBP6"
480
+ },
481
+ "outputs": [],
482
+ "source": [
483
+ "gpu_stats = torch.cuda.get_device_properties(0)\n",
484
+ "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
485
+ "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
486
+ "\n",
487
+ "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
488
+ "print(f\"{start_gpu_memory} GB of memory reserved.\")"
489
+ ]
490
+ },
491
+ {
492
+ "cell_type": "markdown",
493
+ "metadata": {
494
+ "id": "YazYtLAe97Dc"
495
+ },
496
+ "source": [
497
+ "And train!"
498
+ ]
499
+ },
500
+ {
501
+ "cell_type": "code",
502
+ "execution_count": null,
503
+ "metadata": {
504
+ "id": "pbJXrhA0ywra"
505
+ },
506
+ "outputs": [],
507
+ "source": [
508
+ "trainer_stats = trainer.train()"
509
+ ]
510
+ },
511
+ {
512
+ "cell_type": "markdown",
513
+ "metadata": {
514
+ "id": "SmcYN5yW99IP"
515
+ },
516
+ "source": [
517
+ "Show memory stats after training"
518
+ ]
519
+ },
520
+ {
521
+ "cell_type": "code",
522
+ "execution_count": null,
523
+ "metadata": {
524
+ "id": "TrrwP4ADMmrp"
525
+ },
526
+ "outputs": [],
527
+ "source": [
528
+ "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
529
+ "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
530
+ "used_percentage = round(used_memory / max_memory * 100, 3)\n",
531
+ "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
532
+ "\n",
533
+ "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
534
+ "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
535
+ "print(f\"Peak reserved memory = {used_memory} GB.\")\n",
536
+ "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
537
+ "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
538
+ "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
539
+ ]
540
+ },
541
+ {
542
+ "cell_type": "markdown",
543
+ "metadata": {
544
+ "id": "saarW87Y9_-R"
545
+ },
546
+ "source": [
547
+ "## Saving fine tuned model\n",
548
+ "\n",
549
+ "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account."
550
+ ]
551
+ },
552
+ {
553
+ "cell_type": "code",
554
+ "execution_count": null,
555
+ "metadata": {
556
+ "id": "71A8aqEyyETA"
557
+ },
558
+ "outputs": [],
559
+ "source": [
560
+ "trainer.save_model(output_dir)\n",
561
+ "trainer.push_to_hub(dataset_name=dataset_id)"
562
+ ]
563
+ },
564
+ {
565
+ "cell_type": "markdown",
566
+ "metadata": {
567
+ "id": "nfqvO0qw-OvS"
568
+ },
569
+ "source": [
570
+ "## Load the fine-tuned model and run inference\n",
571
+ "\n",
572
+ "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation."
573
+ ]
574
+ },
575
+ {
576
+ "cell_type": "code",
577
+ "execution_count": null,
578
+ "metadata": {
579
+ "id": "R8T2uFQVyFeH"
580
+ },
581
+ "outputs": [],
582
+ "source": [
583
+ "from transformers import Qwen3VLForConditionalGeneration, AutoProcessor\n",
584
+ "from peft import PeftModel\n",
585
+ "\n",
586
+ "base_model = model_name\n",
587
+ "adapter_model = f\"{output_dir}\" # Replace with your HF username or organization\n",
588
+ "\n",
589
+ "model = Qwen3VLForConditionalGeneration.from_pretrained(base_model, dtype=\"float32\", device_map=\"auto\")\n",
590
+ "model = PeftModel.from_pretrained(model, adapter_model)\n",
591
+ "\n",
592
+ "processor = AutoProcessor.from_pretrained(base_model)"
593
+ ]
594
+ },
595
+ {
596
+ "cell_type": "code",
597
+ "execution_count": null,
598
+ "metadata": {
599
+ "id": "dPBHP0CpLa6K"
600
+ },
601
+ "outputs": [],
602
+ "source": [
603
+ "train_dataset[0]"
604
+ ]
605
+ },
606
+ {
607
+ "cell_type": "code",
608
+ "execution_count": null,
609
+ "metadata": {
610
+ "id": "cG5-ccGRyHgo"
611
+ },
612
+ "outputs": [],
613
+ "source": [
614
+ "from datasets import load_dataset\n",
615
+ "\n",
616
+ "dataset_id = 'lmms-lab/multimodal-open-r1-8k-verified'\n",
617
+ "train_dataset = load_dataset(dataset_id, split='train[:5%]')\n",
618
+ "\n",
619
+ "problem = train_dataset[0]['problem']\n",
620
+ "image = train_dataset[0]['image']\n",
621
+ "\n",
622
+ "messages = [\n",
623
+ " {\n",
624
+ " \"role\": \"system\", \"content\": [\n",
625
+ " {\"type\": \"text\", \"text\": SYSTEM_PROMPT}\n",
626
+ " ]\n",
627
+ " },\n",
628
+ " {\n",
629
+ " \"role\": \"user\",\n",
630
+ " \"content\": [\n",
631
+ " {\"type\": \"image\", \"image\": image},\n",
632
+ " {\"type\": \"text\", \"text\": problem},\n",
633
+ " ],\n",
634
+ " },\n",
635
+ "]"
636
+ ]
637
+ },
638
+ {
639
+ "cell_type": "code",
640
+ "execution_count": null,
641
+ "metadata": {
642
+ "id": "r_70q_8lLgfV"
643
+ },
644
+ "outputs": [],
645
+ "source": [
646
+ "messages"
647
+ ]
648
+ },
649
+ {
650
+ "cell_type": "code",
651
+ "execution_count": null,
652
+ "metadata": {
653
+ "id": "PX92MjqlyIwB"
654
+ },
655
+ "outputs": [],
656
+ "source": [
657
+ "inputs = processor.apply_chat_template(\n",
658
+ " messages,\n",
659
+ " add_generation_prompt=True,\n",
660
+ " tokenize=True,\n",
661
+ " return_tensors=\"pt\",\n",
662
+ " return_dict=True,\n",
663
+ ").to(model.device)\n",
664
+ "\n",
665
+ "# Inference: Generation of the output\n",
666
+ "generated_ids = model.generate(**inputs, max_new_tokens=500)\n",
667
+ "generated_ids_trimmed = [\n",
668
+ " out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n",
669
+ "]\n",
670
+ "output_text = processor.batch_decode(\n",
671
+ " generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
672
+ ")\n",
673
+ "print(output_text)"
674
+ ]
675
+ }
676
+ ],
677
+ "metadata": {
678
+ "accelerator": "GPU",
679
+ "colab": {
680
+ "gpuType": "T4",
681
+ "provenance": []
682
+ },
683
+ "kernelspec": {
684
+ "display_name": "Python 3",
685
+ "name": "python3"
686
+ },
687
+ "language_info": {
688
+ "name": "python"
689
+ }
690
+ },
691
+ "nbformat": 4,
692
+ "nbformat_minor": 0
693
+ }
ICL/RL/trl_source/examples/notebooks/grpo_rnj_1_instruct.ipynb ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "-J8iGzLf4rUJ"
7
+ },
8
+ "source": [
9
+ "# GRPO EssentialAI/rnj-1-instruct with QLoRA using TRL\n",
10
+ "\n",
11
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_rnj_1_instruct.ipynb)\n",
12
+ "\n",
13
+ "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n",
14
+ "\n",
15
+ "\n",
16
+ "With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can fine-tune cutting edge large language models. It comes with support for quantized parameter efficient fine-tuning technique **QLoRA**, so we can use Colab to fine-tune models like [EssentialAI/rnj-1-instruct](https://huggingface.co/collections/EssentialAI/rnj-1).\n",
17
+ "\n",
18
+ "\n",
19
+ "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n",
20
+ "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n",
21
+ "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)\n",
22
+ "\n",
23
+ "In this notebook, we'll add reasoning capabilities to the model, teaching it to generate reasoning traces (`<think></think>`) before giving us the final answer (`<answer></answer>`)."
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "markdown",
28
+ "metadata": {
29
+ "id": "NvrzGRnu48Vz"
30
+ },
31
+ "source": [
32
+ "## Install dependencies\n",
33
+ "\n",
34
+ "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training."
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "metadata": {
41
+ "id": "8VOdRz9fgFa8"
42
+ },
43
+ "outputs": [],
44
+ "source": [
45
+ "!pip install -Uq \"trl[peft]\" bitsandbytes trackio math_verify"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "markdown",
50
+ "metadata": {
51
+ "id": "gpzI6omi7728"
52
+ },
53
+ "source": [
54
+ "### Log in to Hugging Face\n",
55
+ "\n",
56
+ "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)."
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "metadata": {
63
+ "id": "d3j3BsdQgFa8"
64
+ },
65
+ "outputs": [],
66
+ "source": [
67
+ "from huggingface_hub import notebook_login\n",
68
+ "\n",
69
+ "notebook_login()"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "markdown",
74
+ "metadata": {
75
+ "id": "V_Zylc4t79-n"
76
+ },
77
+ "source": [
78
+ "## Load dataset\n",
79
+ "\n",
80
+ "\n",
81
+ "We'll load the [**AI-MO/NuminaMath-TIR**](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset from the Hugging Face Hub using the `datasets` library.\n",
82
+ "\n",
83
+ "This dataset contains maths problems, along with the solution in thinking format specially tailored for LLMs. By training our model with this dataset, it'll improve its maths and thinking reasoning.\n",
84
+ "\n",
85
+ "> We only use a subset for educational purposes. In a real scenario, we'd use the complete dataset."
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "metadata": {
92
+ "id": "YSuLNZAmgFa9"
93
+ },
94
+ "outputs": [],
95
+ "source": [
96
+ "from datasets import load_dataset\n",
97
+ "\n",
98
+ "dataset_id = 'AI-MO/NuminaMath-TIR'\n",
99
+ "train_dataset = load_dataset(dataset_id, split='train[:5%]')"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "markdown",
104
+ "metadata": {
105
+ "id": "gVV7RoRN8zk5"
106
+ },
107
+ "source": [
108
+ "In addition to the current columns, we also include a custom system prompt to tell the model how we'd like the generation.\n",
109
+ "\n",
110
+ "This system prompt is an adapted version of the original one extracted from **DeepSeek R1**. For additional background, see [this previous recipe](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl). We extend the prompt with **examples** and a **more explicit, verbose formulation** to make the desired behavior easier for the model to learn. Depending on your goals, you may further enrich the prompt to simplify learning, or intentionally shorten and harden it to encourage more robust and generalizable behavior.\n",
111
+ "\n",
112
+ "We convert the dataset samples into conversation samples, including the system prompt and problem description per sample, since this is how the GRPO trainer expects them.\n",
113
+ "\n",
114
+ "We also set `padding_side=\"left\"` to ensure that generated completions during training are concatenated directly after the prompt, which is essential for GRPO to correctly compare token-level probabilities between preferred and rejected responses."
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "metadata": {
121
+ "id": "vr9t-9Z5gFa9"
122
+ },
123
+ "outputs": [],
124
+ "source": [
125
+ "SYSTEM_PROMPT = \"\"\"A conversation between User and Assistant. The user asks a question, and the Assistant solves it.\n",
126
+ "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\n",
127
+ "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags.\n",
128
+ "Use exactly one <think>...</think> block followed by exactly one <answer>...</answer> block.\n",
129
+ "\n",
130
+ "Examples:\n",
131
+ "\n",
132
+ "User: What is 2 + 2?\n",
133
+ "Assistant:\n",
134
+ "<think>\n",
135
+ "I will add 2 and 2 together.\n",
136
+ "</think>\n",
137
+ "<answer>4</answer>\n",
138
+ "\n",
139
+ "User: What is 3 × 5?\n",
140
+ "Assistant:\n",
141
+ "<think>\n",
142
+ "I will multiply 3 by 5.\n",
143
+ "</think>\n",
144
+ "<answer>15</answer>\n",
145
+ "\n",
146
+ "User: Find the GCD of 12 and 18.\n",
147
+ "Assistant:\n",
148
+ "<think>\n",
149
+ "I will list the divisors of 12 and 18 and find the greatest one they have in common.\n",
150
+ "</think>\n",
151
+ "<answer>6</answer>\n",
152
+ "\"\"\"\n",
153
+ "\n",
154
+ "def make_conversation(example):\n",
155
+ " return {\n",
156
+ " \"prompt\": [\n",
157
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
158
+ " {\"role\": \"user\", \"content\": example[\"problem\"]},\n",
159
+ " ],\n",
160
+ " }\n",
161
+ "\n",
162
+ "train_dataset = train_dataset.map(make_conversation)"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "markdown",
167
+ "metadata": {
168
+ "id": "5txAuMAa8ock"
169
+ },
170
+ "source": [
171
+ "Let's review one example to understand the internal structure:"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": null,
177
+ "metadata": {
178
+ "id": "jZtkB0D9gFa9"
179
+ },
180
+ "outputs": [],
181
+ "source": [
182
+ "print(train_dataset[0])"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "markdown",
187
+ "metadata": {
188
+ "id": "FtdKjmyFZImL"
189
+ },
190
+ "source": [
191
+ "And remove the columns that are not needed for training:"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": null,
197
+ "metadata": {
198
+ "id": "Ai4F1GaPgFa-"
199
+ },
200
+ "outputs": [],
201
+ "source": [
202
+ "train_dataset = train_dataset.remove_columns(['messages', 'problem'])\n",
203
+ "print(train_dataset)"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "markdown",
208
+ "metadata": {
209
+ "id": "YY3uMp909Eqy"
210
+ },
211
+ "source": [
212
+ "## Load model and configure LoRA/QLoRA\n",
213
+ "\n",
214
+ "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration."
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": null,
220
+ "metadata": {
221
+ "id": "DSKcUQ9RgFa-"
222
+ },
223
+ "outputs": [],
224
+ "source": [
225
+ "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n",
226
+ "import torch\n",
227
+ "\n",
228
+ "model_name = \"EssentialAI/rnj-1-instruct\"\n",
229
+ "\n",
230
+ "model = AutoModelForCausalLM.from_pretrained(\n",
231
+ " model_name,\n",
232
+ " dtype=\"float32\",\n",
233
+ " device_map=\"auto\",\n",
234
+ " quantization_config=BitsAndBytesConfig(\n",
235
+ " load_in_4bit=True,\n",
236
+ " bnb_4bit_use_double_quant=True,\n",
237
+ " bnb_4bit_quant_type=\"nf4\",\n",
238
+ " bnb_4bit_compute_dtype=torch.float16\n",
239
+ " ),\n",
240
+ ")"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "markdown",
245
+ "metadata": {
246
+ "id": "WZGf-GF09Gsc"
247
+ },
248
+ "source": [
249
+ "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter**, a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning."
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": null,
255
+ "metadata": {
256
+ "id": "nMMlDxJSgFa-"
257
+ },
258
+ "outputs": [],
259
+ "source": [
260
+ "from peft import LoraConfig\n",
261
+ "\n",
262
+ "# You may need to update `target_modules` depending on the architecture of your chosen model.\n",
263
+ "# For example, different LLMs might have different attention/projection layer names.\n",
264
+ "peft_config = LoraConfig(\n",
265
+ " r=32,\n",
266
+ " lora_alpha=32,\n",
267
+ " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n",
268
+ ")\n"
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "markdown",
273
+ "metadata": {
274
+ "id": "mDq4V6dN9MGk"
275
+ },
276
+ "source": [
277
+ "## Train model\n",
278
+ "\n",
279
+ "We'll configure **GRPO** using `GRPOConfig`, keeping the parameters minimal so the training fits on a Colab instance. You can adjust these settings depending on the resources available. For full details on all available parameters, check the [TRL GRPOConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.GRPOConfig).\n",
280
+ "\n",
281
+ "First, we need to define the rewards functions that the training algorithm will use to improve the model. In this case, we'll include just one reward function.\n",
282
+ "We'll use a format reward that will reward the model when the output includes `<think>` and `<answer>` tags. This is a simplification of the pipeline for educational purposes, but in a real scenario, you'd at least all need a reward function to check the correctness of the model answer. The function has been extracted from [here](https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py).\n",
283
+ "\n",
284
+ "> 💡 **Note**: \n",
285
+ "> You can further refine this reward by making it more granular. For example, assigning partial rewards when `<think>` and `<answer>` appear independently, or when they are present but incorrectly ordered. This can make the learning signal denser and speed up early training. However, overly simplifying the reward may reduce robustness, even if it helps the model converge faster. In practice, there is a trade-off between ease of learning and the generalization quality of the final model."
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "execution_count": null,
291
+ "metadata": {
292
+ "id": "Rtx5owCRgFa-"
293
+ },
294
+ "outputs": [],
295
+ "source": [
296
+ "import re\n",
297
+ "\n",
298
+ "def format_reward(completions, **kwargs):\n",
299
+ " \"\"\"Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags.\"\"\"\n",
300
+ " pattern = r\"<think>.*?</think>.*?<answer>.*?</answer>\"\n",
301
+ "\n",
302
+ " matches = []\n",
303
+ " for item in completions:\n",
304
+ " if isinstance(item, list):\n",
305
+ " text = item[0]['content']\n",
306
+ " else:\n",
307
+ " text = item\n",
308
+ " match = re.match(pattern, text, re.DOTALL | re.MULTILINE)\n",
309
+ " matches.append(match)\n",
310
+ "\n",
311
+ " return [1.0 if match else 0.0 for match in matches]"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "markdown",
316
+ "metadata": {
317
+ "id": "9xBL7Rni9LZb"
318
+ },
319
+ "source": [
320
+ "After defining the reward function(s), we can define the `GRPOConfig`. You can adapt the values in the config depending on your training setting and even fit the training in more constrained setups like free Colab (T4)."
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "code",
325
+ "execution_count": null,
326
+ "metadata": {
327
+ "id": "rJ0VfG3wgFa-"
328
+ },
329
+ "outputs": [],
330
+ "source": [
331
+ "from trl import GRPOConfig\n",
332
+ "\n",
333
+ "output_dir = \"EssentialAI-rnj-1-instruct-trl-grpo\"\n",
334
+ "\n",
335
+ "# Configure training arguments using GRPOConfig\n",
336
+ "training_args = GRPOConfig(\n",
337
+ " learning_rate=2e-5, # Learning rate used during traing\n",
338
+ " num_train_epochs=1, # Number of full dataset passes. For testing, use `max_steps` instead\n",
339
+ " #max_steps=100,\n",
340
+ "\n",
341
+ " # Parameters that control the data preprocessing\n",
342
+ " per_device_train_batch_size=8,\n",
343
+ " max_completion_length=256, # default: 256 # Max completion length produced during training\n",
344
+ " num_generations=8, # default: 8 # Number of generations produced during training for comparison\n",
345
+ "\n",
346
+ " # Parameters related to reporting and saving\n",
347
+ " output_dir=output_dir, # Where to save model checkpoints and logs\n",
348
+ " logging_steps=10, # Log training metrics every N steps\n",
349
+ " report_to=\"trackio\", # Experiment tracking tool\n",
350
+ " trackio_space_id = output_dir, # HF Space where you trackio will be\n",
351
+ "\n",
352
+ " # Hub integration\n",
353
+ " push_to_hub=True, # Push the resulted model to the Hub\n",
354
+ " log_completions=True, # Log completions during training\n",
355
+ ")"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "markdown",
360
+ "metadata": {
361
+ "id": "O0q3myQg927v"
362
+ },
363
+ "source": [
364
+ "Configure the GRPO Trainer. We pass the previously configured `training_args`. We don't use eval dataset to maintain memory usage low but you can configure it."
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "code",
369
+ "execution_count": null,
370
+ "metadata": {
371
+ "id": "aW7Gi4nXgFa-"
372
+ },
373
+ "outputs": [],
374
+ "source": [
375
+ "from trl import GRPOTrainer\n",
376
+ "\n",
377
+ "trainer = GRPOTrainer(\n",
378
+ " model=model,\n",
379
+ " reward_funcs=[format_reward],\n",
380
+ " args=training_args,\n",
381
+ " train_dataset=train_dataset,\n",
382
+ " peft_config=peft_config,\n",
383
+ ")"
384
+ ]
385
+ },
386
+ {
387
+ "cell_type": "markdown",
388
+ "metadata": {
389
+ "id": "kQC7Q5kg95xq"
390
+ },
391
+ "source": [
392
+ "Show memory stats before training"
393
+ ]
394
+ },
395
+ {
396
+ "cell_type": "code",
397
+ "execution_count": null,
398
+ "metadata": {
399
+ "id": "OJdVlC_mgFa_"
400
+ },
401
+ "outputs": [],
402
+ "source": [
403
+ "gpu_stats = torch.cuda.get_device_properties(0)\n",
404
+ "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
405
+ "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
406
+ "\n",
407
+ "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
408
+ "print(f\"{start_gpu_memory} GB of memory reserved.\")"
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "markdown",
413
+ "metadata": {
414
+ "id": "YazYtLAe97Dc"
415
+ },
416
+ "source": [
417
+ "And train!"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "code",
422
+ "execution_count": null,
423
+ "metadata": {
424
+ "id": "Mtv8s7rBgFa_"
425
+ },
426
+ "outputs": [],
427
+ "source": [
428
+ "trainer_stats = trainer.train()"
429
+ ]
430
+ },
431
+ {
432
+ "cell_type": "markdown",
433
+ "metadata": {
434
+ "id": "SmcYN5yW99IP"
435
+ },
436
+ "source": [
437
+ "Show memory stats after training"
438
+ ]
439
+ },
440
+ {
441
+ "cell_type": "code",
442
+ "execution_count": null,
443
+ "metadata": {
444
+ "id": "-ROfX8e9gFa_"
445
+ },
446
+ "outputs": [],
447
+ "source": [
448
+ "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
449
+ "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
450
+ "used_percentage = round(used_memory / max_memory * 100, 3)\n",
451
+ "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
452
+ "\n",
453
+ "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
454
+ "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
455
+ "print(f\"Peak reserved memory = {used_memory} GB.\")\n",
456
+ "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
457
+ "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
458
+ "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "markdown",
463
+ "metadata": {
464
+ "id": "saarW87Y9_-R"
465
+ },
466
+ "source": [
467
+ "## Saving fine tuned model\n",
468
+ "\n",
469
+ "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account."
470
+ ]
471
+ },
472
+ {
473
+ "cell_type": "code",
474
+ "execution_count": null,
475
+ "metadata": {
476
+ "id": "09zYXJ3GgFa_"
477
+ },
478
+ "outputs": [],
479
+ "source": [
480
+ "trainer.save_model(output_dir)\n",
481
+ "trainer.push_to_hub(dataset_name=dataset_id)"
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "markdown",
486
+ "metadata": {
487
+ "id": "nfqvO0qw-OvS"
488
+ },
489
+ "source": [
490
+ "## Load the fine-tuned model and run inference\n",
491
+ "\n",
492
+ "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation."
493
+ ]
494
+ },
495
+ {
496
+ "cell_type": "code",
497
+ "execution_count": null,
498
+ "metadata": {
499
+ "id": "9Yk9RAABgFa_"
500
+ },
501
+ "outputs": [],
502
+ "source": [
503
+ "output_dir = 'sergiopaniego/EssentialAI-rnj-1-instruct-trl-grpo'\n",
504
+ "model_name = \"EssentialAI/rnj-1-instruct\""
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": null,
510
+ "metadata": {
511
+ "id": "CdzlQcCAgFa_"
512
+ },
513
+ "outputs": [],
514
+ "source": [
515
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
516
+ "from peft import PeftModel\n",
517
+ "\n",
518
+ "base_model = model_name\n",
519
+ "adapter_model = f\"{output_dir}\" # Replace with your HF username or organization\n",
520
+ "\n",
521
+ "model = AutoModelForCausalLM.from_pretrained(base_model, dtype=\"float32\", device_map=\"auto\")\n",
522
+ "model = PeftModel.from_pretrained(model, adapter_model)\n",
523
+ "\n",
524
+ "tokenizer = AutoTokenizer.from_pretrained(base_model)"
525
+ ]
526
+ },
527
+ {
528
+ "cell_type": "code",
529
+ "execution_count": null,
530
+ "metadata": {
531
+ "id": "LZgjlAu-gFa_"
532
+ },
533
+ "outputs": [],
534
+ "source": [
535
+ "train_dataset[0]"
536
+ ]
537
+ },
538
+ {
539
+ "cell_type": "code",
540
+ "execution_count": null,
541
+ "metadata": {
542
+ "id": "gjY6TqQHgFa_"
543
+ },
544
+ "outputs": [],
545
+ "source": [
546
+ "from datasets import load_dataset\n",
547
+ "\n",
548
+ "dataset_id = 'AI-MO/NuminaMath-TIR'\n",
549
+ "train_dataset = load_dataset(dataset_id, split='train[:5%]')\n",
550
+ "\n",
551
+ "problem = train_dataset[0]['problem']\n",
552
+ "\n",
553
+ "messages = [\n",
554
+ " {\n",
555
+ " \"role\": \"system\", \"content\": [\n",
556
+ " {\"type\": \"text\", \"text\": SYSTEM_PROMPT}\n",
557
+ " ]\n",
558
+ " },\n",
559
+ " {\n",
560
+ " \"role\": \"user\",\n",
561
+ " \"content\": [\n",
562
+ " {\"type\": \"text\", \"text\": problem},\n",
563
+ " ],\n",
564
+ " },\n",
565
+ "]"
566
+ ]
567
+ },
568
+ {
569
+ "cell_type": "code",
570
+ "execution_count": null,
571
+ "metadata": {
572
+ "id": "eaVubGYmgFa_"
573
+ },
574
+ "outputs": [],
575
+ "source": [
576
+ "messages"
577
+ ]
578
+ },
579
+ {
580
+ "cell_type": "code",
581
+ "execution_count": null,
582
+ "metadata": {
583
+ "id": "2M6Xh4JMgFa_"
584
+ },
585
+ "outputs": [],
586
+ "source": [
587
+ "input_ids = tokenizer.apply_chat_template(\n",
588
+ " messages,\n",
589
+ " add_generation_prompt=True,\n",
590
+ " return_tensors=\"pt\",\n",
591
+ " return_dict=False,\n",
592
+ ").to(model.device)\n",
593
+ "\n",
594
+ "# --- Generate Prediction --- #\n",
595
+ "print(\"Generating prediction...\")\n",
596
+ "output_ids = model.generate(\n",
597
+ " input_ids,\n",
598
+ " max_new_tokens=50,\n",
599
+ " pad_token_id=tokenizer.eos_token_id,\n",
600
+ " do_sample=True,\n",
601
+ " temperature=0.2,\n",
602
+ " top_p=0.95\n",
603
+ ")\n",
604
+ "\n",
605
+ "response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)\n",
606
+ "print(response)"
607
+ ]
608
+ }
609
+ ],
610
+ "metadata": {
611
+ "accelerator": "GPU",
612
+ "colab": {
613
+ "gpuType": "A100",
614
+ "provenance": []
615
+ },
616
+ "language_info": {
617
+ "name": "python"
618
+ }
619
+ },
620
+ "nbformat": 4,
621
+ "nbformat_minor": 0
622
+ }
ICL/RL/trl_source/examples/notebooks/grpo_trl_lora_qlora.ipynb ADDED
@@ -0,0 +1,1638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "27ozP4Uy-Cz2"
7
+ },
8
+ "source": [
9
+ "# Group Relative Policy Optimization (GRPO) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n",
10
+ "\n",
11
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_trl_lora_qlora.ipynb)"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "metadata": {
17
+ "id": "eOjY4AR1-QnF"
18
+ },
19
+ "source": [
20
+ "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n",
21
+ "\n",
22
+ "Easily fine-tune **Large Language Models (LLMs)** or **Vision-Language Models (VLMs)** with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library by Hugging Face and Group Relative Policy Optimization (GRPO) — all within a **free Google Colab notebook** powered by a **T4 GPU**.\n",
23
+ "\n",
24
+ "Thanks to the **built-in memory and training optimizations in TRL**, including LoRA, quantization, gradient checkpointing, and optimized attention kernels, it is possible to **fine-tune a 7B model on a free T4** with a **~7× reduction in memory consumption** compared to naive FP16 training.\n",
25
+ "\n",
26
+ "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n",
27
+ "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n",
28
+ "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "metadata": {
34
+ "id": "w2TnJ6ta-2zj"
35
+ },
36
+ "source": [
37
+ "## Key concepts\n",
38
+ "\n",
39
+ "- **GRPO**: A reinforcement learning algorithm that optimizes a policy by comparing multiple generated responses for the same prompt and updating the model based on their relative rewards, without requiring a separate value model.\n",
40
+ "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n",
41
+ "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n",
42
+ "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n",
43
+ "\n",
44
+ "Learn how to perform **GRPO (Group Relative Policy Optimization)** with **LoRA/QLoRA** using **TRL**."
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "markdown",
49
+ "metadata": {
50
+ "id": "EzScUBxoT4Nt"
51
+ },
52
+ "source": [
53
+ "This table demonstrates how **progressively enabling efficiency techniques** affects **memory usage** and **training throughput** across different hardware configurations. \n",
54
+ "The techniques range from naive FP16 training to **LoRA, quantization, Liger kernels, paged_adamw_8bit, and gradient checkpointing**.\n",
55
+ "\n",
56
+ "| Configuration | LoRA | Quant | Liger | Optimizer | Grad. Ckpt | attn_impl | VRAM (T4) GB | VRAM (A100-40GB)| VRAM (A100-80GB) | Tokens/s (T4) | Tokens/s (A100-40GB) | Tokens/s (A100-80GB) | Status (T4) |\n",
57
+ "|--------------|------|-------|-------|-----------|------------|-----------|---------------|----------------|---------|---------|---------------|------------------|-------------|\n",
58
+ "| **Worst (naive FP16)** | ❌ | ❌ | ❌ | AdamW | ❌ | eager | OOM | OOM | 62 GB | - | - | 0.06 it/s | ❌ |\n",
59
+ "| **Best (all optimizations)** | ✅ | ✅ | ✅ | paged_adamw_8bit | ✅ | sdpa | 9.2 GB | 9.6 GB | 9.6 GB | 0.01 it/s | 0.03 it/s | 0.04 it/s | ✅ |\n",
60
+ "\n",
61
+ "With all efficiency techniques enabled, **memory usage on Colab T4 is reduced by ~7×**, making it possible to **fine-tune a 7B model on free Colab** where naive FP16 training would fail.\n",
62
+ "\n",
63
+ "> A small trade-off in training speed is observed, but the **VRAM reduction is the key enabler**. For faster training on compatible hardware, **vLLM** can also be leveraged.\n",
64
+ "\n",
65
+ "> 💡 Note: For a fair comparison, the number of generations and the batch size were not changed."
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "metadata": {
71
+ "id": "9RFq6Op7rjc3"
72
+ },
73
+ "source": [
74
+ "## Install dependencies\n",
75
+ "\n",
76
+ "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training, and **liger-kernel** for more efficient training."
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": null,
82
+ "metadata": {
83
+ "id": "c2jy45nfWbdo"
84
+ },
85
+ "outputs": [],
86
+ "source": [
87
+ "!pip install -Uq \"trl[peft]\" bitsandbytes trackio math_verify liger-kernel"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "markdown",
92
+ "metadata": {
93
+ "id": "B33zJG_Q_qb3"
94
+ },
95
+ "source": [
96
+ "### Log in to Hugging Face\n",
97
+ "\n",
98
+ "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)."
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "metadata": {
105
+ "colab": {
106
+ "referenced_widgets": [
107
+ "eec717d21e734c4da066763b4a6add7e"
108
+ ]
109
+ },
110
+ "id": "8zqnTyUDWbdo",
111
+ "outputId": "62d71aaf-352b-4736-acb9-189d78654718"
112
+ },
113
+ "outputs": [],
114
+ "source": [
115
+ "from huggingface_hub import notebook_login\n",
116
+ "\n",
117
+ "notebook_login()"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "markdown",
122
+ "metadata": {
123
+ "id": "cTEw4xlFrhnQ"
124
+ },
125
+ "source": [
126
+ "## Load Dataset\n",
127
+ "\n",
128
+ "In this step, we load the [**AI-MO/NuminaMath-TIR**](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset from the Hugging Face Hub using the `datasets` library.\n",
129
+ "This dataset focuses on **mathematical reasoning**, featuring problems that require step-by-step logical solutions.\n",
130
+ "By fine-tuning a model that does not yet exhibit strong reasoning capabilities, it can learn to **generate structured reasoning steps**, enhancing both the model's **accuracy** and **interpretability** on math-related tasks.\n",
131
+ "\n",
132
+ "For efficiency, we'll load only a **small portion of the training split**:"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": null,
138
+ "metadata": {
139
+ "id": "zU5icx67Wbdp",
140
+ "outputId": "6480b287-dc0e-4e79-feda-f5e4f41d2a82"
141
+ },
142
+ "outputs": [],
143
+ "source": [
144
+ "from datasets import load_dataset\n",
145
+ "\n",
146
+ "dataset_name = 'AI-MO/NuminaMath-TIR'\n",
147
+ "train_dataset = load_dataset(dataset_name, split='train[:5%]')"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "markdown",
152
+ "metadata": {
153
+ "id": "P1AIokQrBEGw"
154
+ },
155
+ "source": [
156
+ "Let's check the structure of the dataset"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {
163
+ "id": "ff6Gx1TWWbdp",
164
+ "outputId": "30d49bed-273a-47d9-d131-a677ca5a8b65"
165
+ },
166
+ "outputs": [
167
+ {
168
+ "name": "stdout",
169
+ "output_type": "stream",
170
+ "text": [
171
+ "Dataset({\n",
172
+ " features: ['problem', 'solution', 'messages'],\n",
173
+ " num_rows: 3622\n",
174
+ "})\n"
175
+ ]
176
+ }
177
+ ],
178
+ "source": [
179
+ "print(train_dataset)"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "markdown",
184
+ "metadata": {
185
+ "id": "QY5hkOqDBGns"
186
+ },
187
+ "source": [
188
+ "Let's check one sample:"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "metadata": {
195
+ "id": "-y9c7i29Wbdp",
196
+ "outputId": "760662ea-4db4-4b8e-c234-92ae2c8ecc17"
197
+ },
198
+ "outputs": [
199
+ {
200
+ "name": "stdout",
201
+ "output_type": "stream",
202
+ "text": [
203
+ "{'problem': 'What is the coefficient of $x^2y^6$ in the expansion of $\\\\left(\\\\frac{3}{5}x-\\\\frac{y}{2}\\\\right)^8$? Express your answer as a common fraction.', 'solution': \"To determine the coefficient of \\\\(x^2y^6\\\\) in the expansion of \\\\(\\\\left(\\\\frac{3}{5}x - \\\\frac{y}{2}\\\\right)^8\\\\), we can use the binomial theorem.\\n\\nThe binomial theorem states:\\n\\\\[\\n(a + b)^n = \\\\sum_{k=0}^{n} \\\\binom{n}{k} a^{n-k} b^k\\n\\\\]\\n\\nIn this case, \\\\(a = \\\\frac{3}{5}x\\\\), \\\\(b = -\\\\frac{y}{2}\\\\), and \\\\(n = 8\\\\).\\n\\nWe are interested in the term that contains \\\\(x^2y^6\\\\). In the general term of the binomial expansion:\\n\\\\[\\n\\\\binom{8}{k} \\\\left(\\\\frac{3}{5}x\\\\right)^{8-k} \\\\left(-\\\\frac{y}{2}\\\\right)^k\\n\\\\]\\n\\nTo get \\\\(x^2\\\\), we need \\\\(8 - k = 2\\\\), thus \\\\(k = 6\\\\).\\n\\nSubstituting \\\\(k = 6\\\\) into the expression:\\n\\\\[\\n\\\\binom{8}{6} \\\\left(\\\\frac{3}{5}x\\\\right)^{8-6} \\\\left(-\\\\frac{y}{2}\\\\right)^6 = \\\\binom{8}{6} \\\\left(\\\\frac{3}{5}x\\\\right)^2 \\\\left(-\\\\frac{y}{2}\\\\right)^6\\n\\\\]\\n\\nNow, we will compute each part of this expression.\\n\\n1. Calculate the binomial coefficient \\\\(\\\\binom{8}{6}\\\\).\\n2. Compute \\\\(\\\\left(\\\\frac{3}{5}\\\\right)^2\\\\).\\n3. Compute \\\\(\\\\left(-\\\\frac{y}{2}\\\\right)^6\\\\).\\n4. Combine everything together to get the coefficient of \\\\(x^2y^6\\\\).\\n\\nLet's compute these in Python.\\n```python\\nfrom math import comb\\n\\n# Given values\\nn = 8\\nk = 6\\n\\n# Calculate the binomial coefficient\\nbinom_coeff = comb(n, k)\\n\\n# Compute (3/5)^2\\na_term = (3/5)**2\\n\\n# Compute (-1/2)^6\\nb_term = (-1/2)**6\\n\\n# Combine terms to get the coefficient of x^2y^6\\ncoefficient = binom_coeff * a_term * b_term\\nprint(coefficient)\\n```\\n```output\\n0.1575\\n```\\nThe coefficient of \\\\(x^2y^6\\\\) in the expansion of \\\\(\\\\left(\\\\frac{3}{5}x - \\\\frac{y}{2}\\\\right)^8\\\\) is \\\\(0.1575\\\\). To express this as a common fraction, we recognize that:\\n\\n\\\\[ 0.1575 = \\\\frac{1575}{10000} = \\\\frac{63}{400} \\\\]\\n\\nThus, the coefficient can be expressed as:\\n\\n\\\\[\\n\\\\boxed{\\\\frac{63}{400}}\\n\\\\]\", 'messages': [{'content': 'What is the coefficient of $x^2y^6$ in the expansion of $\\\\left(\\\\frac{3}{5}x-\\\\frac{y}{2}\\\\right)^8$? Express your answer as a common fraction.', 'role': 'user'}, {'content': \"To determine the coefficient of \\\\(x^2y^6\\\\) in the expansion of \\\\(\\\\left(\\\\frac{3}{5}x - \\\\frac{y}{2}\\\\right)^8\\\\), we can use the binomial theorem.\\n\\nThe binomial theorem states:\\n\\\\[\\n(a + b)^n = \\\\sum_{k=0}^{n} \\\\binom{n}{k} a^{n-k} b^k\\n\\\\]\\n\\nIn this case, \\\\(a = \\\\frac{3}{5}x\\\\), \\\\(b = -\\\\frac{y}{2}\\\\), and \\\\(n = 8\\\\).\\n\\nWe are interested in the term that contains \\\\(x^2y^6\\\\). In the general term of the binomial expansion:\\n\\\\[\\n\\\\binom{8}{k} \\\\left(\\\\frac{3}{5}x\\\\right)^{8-k} \\\\left(-\\\\frac{y}{2}\\\\right)^k\\n\\\\]\\n\\nTo get \\\\(x^2\\\\), we need \\\\(8 - k = 2\\\\), thus \\\\(k = 6\\\\).\\n\\nSubstituting \\\\(k = 6\\\\) into the expression:\\n\\\\[\\n\\\\binom{8}{6} \\\\left(\\\\frac{3}{5}x\\\\right)^{8-6} \\\\left(-\\\\frac{y}{2}\\\\right)^6 = \\\\binom{8}{6} \\\\left(\\\\frac{3}{5}x\\\\right)^2 \\\\left(-\\\\frac{y}{2}\\\\right)^6\\n\\\\]\\n\\nNow, we will compute each part of this expression.\\n\\n1. Calculate the binomial coefficient \\\\(\\\\binom{8}{6}\\\\).\\n2. Compute \\\\(\\\\left(\\\\frac{3}{5}\\\\right)^2\\\\).\\n3. Compute \\\\(\\\\left(-\\\\frac{y}{2}\\\\right)^6\\\\).\\n4. Combine everything together to get the coefficient of \\\\(x^2y^6\\\\).\\n\\nLet's compute these in Python.\\n```python\\nfrom math import comb\\n\\n# Given values\\nn = 8\\nk = 6\\n\\n# Calculate the binomial coefficient\\nbinom_coeff = comb(n, k)\\n\\n# Compute (3/5)^2\\na_term = (3/5)**2\\n\\n# Compute (-1/2)^6\\nb_term = (-1/2)**6\\n\\n# Combine terms to get the coefficient of x^2y^6\\ncoefficient = binom_coeff * a_term * b_term\\nprint(coefficient)\\n```\\n```output\\n0.1575\\n```\\nThe coefficient of \\\\(x^2y^6\\\\) in the expansion of \\\\(\\\\left(\\\\frac{3}{5}x - \\\\frac{y}{2}\\\\right)^8\\\\) is \\\\(0.1575\\\\). To express this as a common fraction, we recognize that:\\n\\n\\\\[ 0.1575 = \\\\frac{1575}{10000} = \\\\frac{63}{400} \\\\]\\n\\nThus, the coefficient can be expressed as:\\n\\n\\\\[\\n\\\\boxed{\\\\frac{63}{400}}\\n\\\\]\", 'role': 'assistant'}]}\n"
204
+ ]
205
+ }
206
+ ],
207
+ "source": [
208
+ "print(train_dataset[0])"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "markdown",
213
+ "metadata": {
214
+ "id": "DiqBlxK_A0SD"
215
+ },
216
+ "source": [
217
+ "We will adapt our dataset to a conversational format using a custom system prompt, guiding the LLM to generate both step-by-step reasoning and the final answer."
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": null,
223
+ "metadata": {
224
+ "id": "RWxK5xFKWbdp"
225
+ },
226
+ "outputs": [],
227
+ "source": [
228
+ "SYSTEM_PROMPT = (\n",
229
+ " \"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant \"\n",
230
+ " \"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning \"\n",
231
+ " \"process is enclosed strictly within <think> and </think> tags. \"\n",
232
+ " \"After closing </think>, the assistant MUST provide the final answer in plain text.\"\n",
233
+ ")\n",
234
+ "\n",
235
+ "\n",
236
+ "def make_conversation(example):\n",
237
+ " return {\n",
238
+ " \"prompt\": [\n",
239
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
240
+ " {\"role\": \"user\", \"content\": example[\"problem\"]},\n",
241
+ " ],\n",
242
+ " }\n",
243
+ "\n",
244
+ "train_dataset = train_dataset.map(make_conversation)"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "markdown",
249
+ "metadata": {
250
+ "id": "sND566XAC0kD"
251
+ },
252
+ "source": [
253
+ "Let's take a look at an example:"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": null,
259
+ "metadata": {
260
+ "id": "Q-kHUmpMWbdp",
261
+ "outputId": "452beb3a-1091-46d4-997e-04b91562d66c"
262
+ },
263
+ "outputs": [
264
+ {
265
+ "name": "stdout",
266
+ "output_type": "stream",
267
+ "text": [
268
+ "[{'content': 'A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process is enclosed strictly within <think> and </think> tags. After closing </think>, the assistant MUST provide the final answer in plain text.', 'role': 'system'}, {'content': 'What is the coefficient of $x^2y^6$ in the expansion of $\\\\left(\\\\frac{3}{5}x-\\\\frac{y}{2}\\\\right)^8$? Express your answer as a common fraction.', 'role': 'user'}]\n"
269
+ ]
270
+ }
271
+ ],
272
+ "source": [
273
+ "print(train_dataset[0]['prompt'])"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "markdown",
278
+ "metadata": {
279
+ "id": "bw0qcp-CC3G0"
280
+ },
281
+ "source": [
282
+ "We'll remove the `messages` and `problem` columns, as we only need the custom `prompt` column and `solution` to verify the generated answer."
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": null,
288
+ "metadata": {
289
+ "id": "SzbF3hdRWbdp",
290
+ "outputId": "bd59a383-1d4e-4020-c232-79ce66073fd1"
291
+ },
292
+ "outputs": [
293
+ {
294
+ "name": "stdout",
295
+ "output_type": "stream",
296
+ "text": [
297
+ "Dataset({\n",
298
+ " features: ['solution', 'prompt'],\n",
299
+ " num_rows: 3622\n",
300
+ "})\n"
301
+ ]
302
+ }
303
+ ],
304
+ "source": [
305
+ "train_dataset = train_dataset.remove_columns(['messages', 'problem'])\n",
306
+ "print(train_dataset)"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "markdown",
311
+ "metadata": {
312
+ "id": "tvs5rjQBr7af"
313
+ },
314
+ "source": [
315
+ "## Load model and configure LoRA/QLoRA\n",
316
+ "\n",
317
+ "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**.\n",
318
+ "\n",
319
+ "> 💡 Note: Some models, such as Qwen2.5 and Qwen3, are known to have been pretrained on data that improves their math performance. Be cautious when selecting the appropriate model for training to ensure meaningful fine-tuning results ([source](https://thinkingmachines.ai/blog/lora/))."
320
+ ]
321
+ },
322
+ {
323
+ "cell_type": "code",
324
+ "execution_count": null,
325
+ "metadata": {
326
+ "id": "7_uaW3JfWbdp"
327
+ },
328
+ "outputs": [],
329
+ "source": [
330
+ "# Select one model below by uncommenting the line you want to use 👇\n",
331
+ "## Qwen\n",
332
+ "model_id, output_dir = \"Qwen/Qwen2-7B-Instruct\", \"t4-Qwen2-7B-Instruct-GRPO\" # ✅ ~9.2GB VRAM\n",
333
+ "# model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-GRPO\" # ⚠️ OOM with this config; fits if GRPO params are reduced\n",
334
+ "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-GRPO\" # ✅ ~9.9GB VRAM\n",
335
+ "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct-GRPO\" # ✅ ~9.2GB VRAM\n",
336
+ "\n",
337
+ "## Llama\n",
338
+ "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct-GRPO\" # ✅ ~5.7GB VRAM\n",
339
+ "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct-GRPO\" # ✅ ~9.5GB VRAM\n",
340
+ "\n",
341
+ "## LFM2.5\n",
342
+ "# model_id, output_dir = \"LiquidAI/LFM2.5-1.2B-Instruct\", \"LFM2.5-1.2B-Instruct-GRPO\" # ✅ ~1.12 GB VRAM"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "markdown",
347
+ "metadata": {
348
+ "id": "aw__94OWDnER"
349
+ },
350
+ "source": [
351
+ "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration (training without quantization consumes more memory).\n",
352
+ "\n",
353
+ "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically."
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "execution_count": null,
359
+ "metadata": {
360
+ "colab": {
361
+ "referenced_widgets": [
362
+ "1130e5a744864ca5b5873731e4764983"
363
+ ]
364
+ },
365
+ "id": "o86TnTchWbdp",
366
+ "outputId": "77a7e6c8-0360-40f1-eea7-b941be031366"
367
+ },
368
+ "outputs": [
369
+ {
370
+ "data": {
371
+ "application/vnd.jupyter.widget-view+json": {
372
+ "model_id": "1130e5a744864ca5b5873731e4764983",
373
+ "version_major": 2,
374
+ "version_minor": 0
375
+ },
376
+ "text/plain": [
377
+ "Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
378
+ ]
379
+ },
380
+ "metadata": {},
381
+ "output_type": "display_data"
382
+ }
383
+ ],
384
+ "source": [
385
+ "import torch\n",
386
+ "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n",
387
+ "\n",
388
+ "model = AutoModelForCausalLM.from_pretrained(\n",
389
+ " model_id,\n",
390
+ " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n",
391
+ " dtype=\"float32\", # Change to bfloat16 if GPU has support\n",
392
+ " quantization_config=BitsAndBytesConfig(\n",
393
+ " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n",
394
+ " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n",
395
+ " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n",
396
+ " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n",
397
+ " )\n",
398
+ ")"
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "markdown",
403
+ "metadata": {
404
+ "id": "AM-G0_QmDyZC"
405
+ },
406
+ "source": [
407
+ "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter**, a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning."
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "execution_count": null,
413
+ "metadata": {
414
+ "id": "WIz2pmX6Wbdp"
415
+ },
416
+ "outputs": [],
417
+ "source": [
418
+ "from peft import LoraConfig\n",
419
+ "\n",
420
+ "# You may need to update `target_modules` depending on the architecture of your chosen model.\n",
421
+ "# For example, different LLMs might have different attention/projection layer names.\n",
422
+ "peft_config = LoraConfig(\n",
423
+ " r=32,\n",
424
+ " lora_alpha=32,\n",
425
+ " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n",
426
+ ")"
427
+ ]
428
+ },
429
+ {
430
+ "cell_type": "markdown",
431
+ "metadata": {
432
+ "id": "prKnAp-Esyiq"
433
+ },
434
+ "source": [
435
+ "## Train model\n",
436
+ "\n",
437
+ "GRPO requires **reward functions** to guide the learning process. For convenience, we can directly load pre-defined rewards from `trl.rewards`, which already includes a [collection of ready-to-use rewards](https://huggingface.co/docs/trl/rewards).\n",
438
+ "\n",
439
+ "If you want to create your own custom reward functions to teach the model, a reward function is simply a Python function that takes the generated completions and returns a list of floats. For example, the following function, which we use in this notebook, rewards completions that correctly follow the `<think>` format:\n",
440
+ "\n",
441
+ "```python\n",
442
+ "def think_format_reward(completions: list[list[dict[str, str]]], **kwargs) -> list[float]:\n",
443
+ " pattern = r\"^<think>(?!.*<think>)(.*?)</think>.*$\"\n",
444
+ " completion_contents = [completion[0][\"content\"] for completion in completions]\n",
445
+ " matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]\n",
446
+ " return [1.0 if match else 0.0 for match in matches]\n",
447
+ "```\n",
448
+ "\n",
449
+ "In this notebook, we will use both `think_format_reward`, which rewards completions that correctly follow the `<think>` format, and `reasoning_accuracy_reward`, which evaluates the correctness of the model's solution to the mathematical problem. Together, these rewards guide the model to generate **structured reasoning** while producing **accurate answers**."
450
+ ]
451
+ },
452
+ {
453
+ "cell_type": "code",
454
+ "execution_count": null,
455
+ "metadata": {
456
+ "id": "lj42Qs5vWbdp"
457
+ },
458
+ "outputs": [],
459
+ "source": [
460
+ "from trl.rewards import think_format_reward, reasoning_accuracy_reward"
461
+ ]
462
+ },
463
+ {
464
+ "cell_type": "markdown",
465
+ "metadata": {
466
+ "id": "bFgYgxMbtbEZ"
467
+ },
468
+ "source": [
469
+ "We'll configure **GRPO** using `GRPOConfig`, keeping the parameters minimal so that the training can run on a free Colab instance. You can adjust these settings if you have access to more resources. For a complete list of available parameters and their descriptions, refer to the [TRL GRPOConfig documentation](https://huggingface.co/docs/trl/grpo_trainer#trl.GRPOConfig).\n",
470
+ "\n",
471
+ "> 💡 Note: TRL supports using **vLLM** for generation during GRPO training, which can significantly speed up training. However, it increases VRAM usage since a separate vLLM process is active to handle generation. In this notebook, we do not enable vLLM because we are using **QLoRA**, which updates the quantized vLLM model weights at every step. Enabling vLLM in this setup can cause weight precision issues and make convergence more challenging. The configuration includes the vLLM parameters in case you want to experiment with it. Learn more about vLLM integration in TRL [here](https://huggingface.co/docs/trl/main/en/vllm_integration)."
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "code",
476
+ "execution_count": null,
477
+ "metadata": {
478
+ "id": "JY11EQMhWbdp"
479
+ },
480
+ "outputs": [],
481
+ "source": [
482
+ "from trl import GRPOConfig\n",
483
+ "\n",
484
+ "# Configure training arguments using GRPOConfig\n",
485
+ "training_args = GRPOConfig(\n",
486
+ " # Training schedule / optimization\n",
487
+ " learning_rate=2e-5, # Learning rate for the optimizer\n",
488
+ " #num_train_epochs=1,\n",
489
+ " max_steps=500, # Number of dataset passes. For full trainings, use `num_train_epochs` instead\n",
490
+ "\n",
491
+ " # Parameters that control GRPO training (you can adapt them)\n",
492
+ " per_device_train_batch_size = 8,\n",
493
+ " max_completion_length=256, # default: 256 # Max completion length produced during training\n",
494
+ " num_generations=8, # default: 8 # Number of generations produced during trainig for comparison\n",
495
+ "\n",
496
+ " # Optimizations\n",
497
+ " optim = \"paged_adamw_8bit\", # Optimizer\n",
498
+ " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n",
499
+ "\n",
500
+ " # Parameters related to reporting and saving\n",
501
+ " output_dir=output_dir, # Where to save model checkpoints and logs\n",
502
+ " logging_steps=10, # Log training metrics every N steps\n",
503
+ " report_to=\"trackio\", # Experiment tracking tool\n",
504
+ " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n",
505
+ " log_completions=False, # Return model completions during training\n",
506
+ "\n",
507
+ " # Hub integration\n",
508
+ " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n",
509
+ " # The model will be saved under your Hub account in the repository named `output_dir`\n",
510
+ " # vLLM params\n",
511
+ " #use_vllm=False, # Activate vLLM training for faster training\n",
512
+ " #vllm_mode='colocate',\n",
513
+ " #vllm_gpu_memory_utilization=0.1,\n",
514
+ " #vllm_enable_sleep_mode=True\n",
515
+ ")"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "markdown",
520
+ "metadata": {
521
+ "id": "-9LlOAvWFSor"
522
+ },
523
+ "source": [
524
+ "Configure the `GRPOTrainer` by passing the previously defined `training_args`. To keep memory usage low, we are not using an evaluation dataset, but you can include one if desired. We also provide the reward functions that were imported earlier to guide the training process."
525
+ ]
526
+ },
527
+ {
528
+ "cell_type": "code",
529
+ "execution_count": null,
530
+ "metadata": {
531
+ "id": "iI_E9KCUWbdq"
532
+ },
533
+ "outputs": [],
534
+ "source": [
535
+ "from trl import GRPOTrainer\n",
536
+ "\n",
537
+ "trainer = GRPOTrainer(\n",
538
+ " model=model,\n",
539
+ " reward_funcs=[think_format_reward, reasoning_accuracy_reward],\n",
540
+ " args=training_args,\n",
541
+ " train_dataset=train_dataset,\n",
542
+ " peft_config=peft_config,\n",
543
+ ")"
544
+ ]
545
+ },
546
+ {
547
+ "cell_type": "markdown",
548
+ "metadata": {
549
+ "id": "8dY7bK8FGLhh"
550
+ },
551
+ "source": [
552
+ "Show memory stats before training"
553
+ ]
554
+ },
555
+ {
556
+ "cell_type": "code",
557
+ "execution_count": null,
558
+ "metadata": {
559
+ "id": "PEVRGlrAWbdq",
560
+ "outputId": "78fac9e4-4ae6-4836-bd10-c30b39059782"
561
+ },
562
+ "outputs": [
563
+ {
564
+ "name": "stdout",
565
+ "output_type": "stream",
566
+ "text": [
567
+ "GPU = Tesla T4. Max memory = 14.741 GB.\n",
568
+ "6.773 GB of memory reserved.\n"
569
+ ]
570
+ }
571
+ ],
572
+ "source": [
573
+ "gpu_stats = torch.cuda.get_device_properties(0)\n",
574
+ "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
575
+ "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
576
+ "\n",
577
+ "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
578
+ "print(f\"{start_gpu_memory} GB of memory reserved.\")"
579
+ ]
580
+ },
581
+ {
582
+ "cell_type": "markdown",
583
+ "metadata": {
584
+ "id": "z-5xPtfIGQL5"
585
+ },
586
+ "source": [
587
+ "And train!"
588
+ ]
589
+ },
590
+ {
591
+ "cell_type": "markdown",
592
+ "metadata": {},
593
+ "source": [
594
+ "Training on a T4 in Colab with the configuration defined in this notebook takes around 13 hours. If you're just experimenting, you can try the following quicker task ([source](https://huggingface.co/learn/llm-course/en/chapter12/5)):\n",
595
+ "\n",
596
+ "```python\n",
597
+ "dataset = load_dataset(\"mlabonne/smoltldr\")\n",
598
+ "\n",
599
+ "# Reward function\n",
600
+ "ideal_length = 50\n",
601
+ "\n",
602
+ "def reward_len(completions, **kwargs):\n",
603
+ " return [-abs(ideal_length - len(completion)) for completion in completions]\n",
604
+ "```"
605
+ ]
606
+ },
607
+ {
608
+ "cell_type": "code",
609
+ "execution_count": null,
610
+ "metadata": {
611
+ "id": "zl7-PmoXWbdq",
612
+ "outputId": "f39c8c3c-43c2-4f2d-c98d-4c595ae1129f"
613
+ },
614
+ "outputs": [
615
+ {
616
+ "name": "stderr",
617
+ "output_type": "stream",
618
+ "text": [
619
+ "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.\n"
620
+ ]
621
+ },
622
+ {
623
+ "name": "stdout",
624
+ "output_type": "stream",
625
+ "text": [
626
+ "* Trackio project initialized: huggingface\n",
627
+ "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/t4-Qwen2-7B-Instruct-GRPO-dataset\n",
628
+ "* Creating new space: https://huggingface.co/spaces/sergiopaniego/t4-Qwen2-7B-Instruct-GRPO\n",
629
+ "* View dashboard by going to: https://sergiopaniego-t4-Qwen2-7B-Instruct-GRPO.hf.space/\n"
630
+ ]
631
+ },
632
+ {
633
+ "data": {
634
+ "text/html": [
635
+ "<div><iframe src=\"https://sergiopaniego-t4-Qwen2-7B-Instruct-GRPO.hf.space/\" width=\"100%\" height=\"1000px\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
636
+ ],
637
+ "text/plain": [
638
+ "<IPython.core.display.HTML object>"
639
+ ]
640
+ },
641
+ "metadata": {},
642
+ "output_type": "display_data"
643
+ },
644
+ {
645
+ "name": "stdout",
646
+ "output_type": "stream",
647
+ "text": [
648
+ "* Created new run: sergiopaniego-1766143600\n"
649
+ ]
650
+ },
651
+ {
652
+ "data": {
653
+ "text/html": [
654
+ "\n",
655
+ " <div>\n",
656
+ " \n",
657
+ " <progress value='500' max='500' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
658
+ " [500/500 13:05:04, Epoch 0/1]\n",
659
+ " </div>\n",
660
+ " <table border=\"1\" class=\"dataframe\">\n",
661
+ " <thead>\n",
662
+ " <tr style=\"text-align: left;\">\n",
663
+ " <th>Step</th>\n",
664
+ " <th>Training Loss</th>\n",
665
+ " </tr>\n",
666
+ " </thead>\n",
667
+ " <tbody>\n",
668
+ " <tr>\n",
669
+ " <td>10</td>\n",
670
+ " <td>0.027900</td>\n",
671
+ " </tr>\n",
672
+ " <tr>\n",
673
+ " <td>20</td>\n",
674
+ " <td>-0.011600</td>\n",
675
+ " </tr>\n",
676
+ " <tr>\n",
677
+ " <td>30</td>\n",
678
+ " <td>0.021500</td>\n",
679
+ " </tr>\n",
680
+ " <tr>\n",
681
+ " <td>40</td>\n",
682
+ " <td>0.033400</td>\n",
683
+ " </tr>\n",
684
+ " <tr>\n",
685
+ " <td>50</td>\n",
686
+ " <td>0.039400</td>\n",
687
+ " </tr>\n",
688
+ " <tr>\n",
689
+ " <td>60</td>\n",
690
+ " <td>0.010300</td>\n",
691
+ " </tr>\n",
692
+ " <tr>\n",
693
+ " <td>70</td>\n",
694
+ " <td>0.048200</td>\n",
695
+ " </tr>\n",
696
+ " <tr>\n",
697
+ " <td>80</td>\n",
698
+ " <td>0.067300</td>\n",
699
+ " </tr>\n",
700
+ " <tr>\n",
701
+ " <td>90</td>\n",
702
+ " <td>0.030600</td>\n",
703
+ " </tr>\n",
704
+ " <tr>\n",
705
+ " <td>100</td>\n",
706
+ " <td>0.064000</td>\n",
707
+ " </tr>\n",
708
+ " <tr>\n",
709
+ " <td>110</td>\n",
710
+ " <td>0.021500</td>\n",
711
+ " </tr>\n",
712
+ " <tr>\n",
713
+ " <td>120</td>\n",
714
+ " <td>0.021400</td>\n",
715
+ " </tr>\n",
716
+ " <tr>\n",
717
+ " <td>130</td>\n",
718
+ " <td>0.000000</td>\n",
719
+ " </tr>\n",
720
+ " <tr>\n",
721
+ " <td>140</td>\n",
722
+ " <td>-0.028500</td>\n",
723
+ " </tr>\n",
724
+ " <tr>\n",
725
+ " <td>150</td>\n",
726
+ " <td>-0.003100</td>\n",
727
+ " </tr>\n",
728
+ " <tr>\n",
729
+ " <td>160</td>\n",
730
+ " <td>0.017300</td>\n",
731
+ " </tr>\n",
732
+ " <tr>\n",
733
+ " <td>170</td>\n",
734
+ " <td>-0.024700</td>\n",
735
+ " </tr>\n",
736
+ " <tr>\n",
737
+ " <td>180</td>\n",
738
+ " <td>0.003300</td>\n",
739
+ " </tr>\n",
740
+ " <tr>\n",
741
+ " <td>190</td>\n",
742
+ " <td>0.000000</td>\n",
743
+ " </tr>\n",
744
+ " <tr>\n",
745
+ " <td>200</td>\n",
746
+ " <td>-0.001400</td>\n",
747
+ " </tr>\n",
748
+ " <tr>\n",
749
+ " <td>210</td>\n",
750
+ " <td>0.008000</td>\n",
751
+ " </tr>\n",
752
+ " <tr>\n",
753
+ " <td>220</td>\n",
754
+ " <td>0.034300</td>\n",
755
+ " </tr>\n",
756
+ " <tr>\n",
757
+ " <td>230</td>\n",
758
+ " <td>0.044600</td>\n",
759
+ " </tr>\n",
760
+ " <tr>\n",
761
+ " <td>240</td>\n",
762
+ " <td>0.016400</td>\n",
763
+ " </tr>\n",
764
+ " <tr>\n",
765
+ " <td>250</td>\n",
766
+ " <td>-0.015200</td>\n",
767
+ " </tr>\n",
768
+ " <tr>\n",
769
+ " <td>260</td>\n",
770
+ " <td>0.016800</td>\n",
771
+ " </tr>\n",
772
+ " <tr>\n",
773
+ " <td>270</td>\n",
774
+ " <td>0.042900</td>\n",
775
+ " </tr>\n",
776
+ " <tr>\n",
777
+ " <td>280</td>\n",
778
+ " <td>0.031300</td>\n",
779
+ " </tr>\n",
780
+ " <tr>\n",
781
+ " <td>290</td>\n",
782
+ " <td>0.006200</td>\n",
783
+ " </tr>\n",
784
+ " <tr>\n",
785
+ " <td>300</td>\n",
786
+ " <td>0.043300</td>\n",
787
+ " </tr>\n",
788
+ " <tr>\n",
789
+ " <td>310</td>\n",
790
+ " <td>0.029700</td>\n",
791
+ " </tr>\n",
792
+ " <tr>\n",
793
+ " <td>320</td>\n",
794
+ " <td>0.001100</td>\n",
795
+ " </tr>\n",
796
+ " <tr>\n",
797
+ " <td>330</td>\n",
798
+ " <td>0.027000</td>\n",
799
+ " </tr>\n",
800
+ " <tr>\n",
801
+ " <td>340</td>\n",
802
+ " <td>-0.006700</td>\n",
803
+ " </tr>\n",
804
+ " <tr>\n",
805
+ " <td>350</td>\n",
806
+ " <td>0.027200</td>\n",
807
+ " </tr>\n",
808
+ " <tr>\n",
809
+ " <td>360</td>\n",
810
+ " <td>0.008200</td>\n",
811
+ " </tr>\n",
812
+ " <tr>\n",
813
+ " <td>370</td>\n",
814
+ " <td>-0.015800</td>\n",
815
+ " </tr>\n",
816
+ " <tr>\n",
817
+ " <td>380</td>\n",
818
+ " <td>0.007200</td>\n",
819
+ " </tr>\n",
820
+ " <tr>\n",
821
+ " <td>390</td>\n",
822
+ " <td>0.012100</td>\n",
823
+ " </tr>\n",
824
+ " <tr>\n",
825
+ " <td>400</td>\n",
826
+ " <td>0.000000</td>\n",
827
+ " </tr>\n",
828
+ " <tr>\n",
829
+ " <td>410</td>\n",
830
+ " <td>0.010500</td>\n",
831
+ " </tr>\n",
832
+ " <tr>\n",
833
+ " <td>420</td>\n",
834
+ " <td>0.019800</td>\n",
835
+ " </tr>\n",
836
+ " <tr>\n",
837
+ " <td>430</td>\n",
838
+ " <td>0.000800</td>\n",
839
+ " </tr>\n",
840
+ " <tr>\n",
841
+ " <td>440</td>\n",
842
+ " <td>0.003400</td>\n",
843
+ " </tr>\n",
844
+ " <tr>\n",
845
+ " <td>450</td>\n",
846
+ " <td>-0.007900</td>\n",
847
+ " </tr>\n",
848
+ " <tr>\n",
849
+ " <td>460</td>\n",
850
+ " <td>-0.011800</td>\n",
851
+ " </tr>\n",
852
+ " <tr>\n",
853
+ " <td>470</td>\n",
854
+ " <td>-0.016300</td>\n",
855
+ " </tr>\n",
856
+ " <tr>\n",
857
+ " <td>480</td>\n",
858
+ " <td>-0.002300</td>\n",
859
+ " </tr>\n",
860
+ " <tr>\n",
861
+ " <td>490</td>\n",
862
+ " <td>-0.005500</td>\n",
863
+ " </tr>\n",
864
+ " <tr>\n",
865
+ " <td>500</td>\n",
866
+ " <td>0.038000</td>\n",
867
+ " </tr>\n",
868
+ " </tbody>\n",
869
+ "</table><p>"
870
+ ],
871
+ "text/plain": [
872
+ "<IPython.core.display.HTML object>"
873
+ ]
874
+ },
875
+ "metadata": {},
876
+ "output_type": "display_data"
877
+ },
878
+ {
879
+ "name": "stdout",
880
+ "output_type": "stream",
881
+ "text": [
882
+ "* Run finished. Uploading logs to Trackio (please wait...)\n"
883
+ ]
884
+ }
885
+ ],
886
+ "source": [
887
+ "trainer_stats = trainer.train()"
888
+ ]
889
+ },
890
+ {
891
+ "cell_type": "markdown",
892
+ "metadata": {
893
+ "id": "iqAN-XLCGTGW"
894
+ },
895
+ "source": [
896
+ "Show memory stats after training"
897
+ ]
898
+ },
899
+ {
900
+ "cell_type": "code",
901
+ "execution_count": null,
902
+ "metadata": {
903
+ "id": "4BeEwp5EWbds",
904
+ "outputId": "668b8a2c-2eef-4e34-8d4a-2a43ccbbdc00"
905
+ },
906
+ "outputs": [
907
+ {
908
+ "name": "stdout",
909
+ "output_type": "stream",
910
+ "text": [
911
+ "47228.679 seconds used for training.\n",
912
+ "787.14 minutes used for training.\n",
913
+ "Peak reserved memory = 8.832 GB.\n",
914
+ "Peak reserved memory for training = 2.059 GB.\n",
915
+ "Peak reserved memory % of max memory = 59.915 %.\n",
916
+ "Peak reserved memory for training % of max memory = 13.968 %.\n"
917
+ ]
918
+ }
919
+ ],
920
+ "source": [
921
+ "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
922
+ "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
923
+ "used_percentage = round(used_memory / max_memory * 100, 3)\n",
924
+ "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
925
+ "\n",
926
+ "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
927
+ "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
928
+ "print(f\"Peak reserved memory = {used_memory} GB.\")\n",
929
+ "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
930
+ "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
931
+ "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
932
+ ]
933
+ },
934
+ {
935
+ "cell_type": "markdown",
936
+ "metadata": {
937
+ "id": "R8Sd_AqILeYi"
938
+ },
939
+ "source": [
940
+ "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. Example outputs would look like the following:"
941
+ ]
942
+ },
943
+ {
944
+ "cell_type": "markdown",
945
+ "metadata": {
946
+ "id": "2bPn6gruLf-n"
947
+ },
948
+ "source": [
949
+ "<img src=\"https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo-qlora-notebook-trackio.png\" width=\"50%\">"
950
+ ]
951
+ },
952
+ {
953
+ "cell_type": "markdown",
954
+ "metadata": {
955
+ "id": "ibO4f7tuLboQ"
956
+ },
957
+ "source": [
958
+ "## Saving fine tuned model\n",
959
+ "\n",
960
+ "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account."
961
+ ]
962
+ },
963
+ {
964
+ "cell_type": "code",
965
+ "execution_count": null,
966
+ "metadata": {
967
+ "colab": {
968
+ "referenced_widgets": [
969
+ "e6a3677667ce47bcba55e3e950e446f9",
970
+ "17adb84604d84cf688a89a21f6cc6150",
971
+ "a21c1bbd3cd04738a8c96fbfc0c016c6",
972
+ "65cadde3da7642188f029bb2aceaa7c6",
973
+ "0404b89e5ce24e76958c72bedc1a95cc",
974
+ "c52baf990fde40c0873747e827dc6926",
975
+ "191653e8ce184123a68f26fbf2b78745",
976
+ "0bb882d400864b249c80132264de2623",
977
+ "09cbfcf6e51c431798f4e392a81be6d3",
978
+ "d6521f73f23f42e18ee462a547f251a1"
979
+ ]
980
+ },
981
+ "id": "itpVDjy0Wbdt",
982
+ "outputId": "b821c7ed-6c9d-440a-a797-e25291627bef"
983
+ },
984
+ "outputs": [],
985
+ "source": [
986
+ "trainer.save_model(output_dir)\n",
987
+ "trainer.push_to_hub(dataset_name=dataset_name)"
988
+ ]
989
+ },
990
+ {
991
+ "cell_type": "markdown",
992
+ "metadata": {
993
+ "id": "81eBZe-X7daz"
994
+ },
995
+ "source": [
996
+ "## Load the fine-tuned model and run inference\n",
997
+ "\n",
998
+ "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation."
999
+ ]
1000
+ },
1001
+ {
1002
+ "cell_type": "code",
1003
+ "execution_count": null,
1004
+ "metadata": {
1005
+ "colab": {
1006
+ "referenced_widgets": [
1007
+ "1d3fbf86d53845beac599c5b231e87ea"
1008
+ ]
1009
+ },
1010
+ "id": "ZLdaWYzNWbdt",
1011
+ "outputId": "a103b64b-1f6b-4423-c5fd-402f210e6dc3"
1012
+ },
1013
+ "outputs": [
1014
+ {
1015
+ "data": {
1016
+ "application/vnd.jupyter.widget-view+json": {
1017
+ "model_id": "1d3fbf86d53845beac599c5b231e87ea",
1018
+ "version_major": 2,
1019
+ "version_minor": 0
1020
+ },
1021
+ "text/plain": [
1022
+ "Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
1023
+ ]
1024
+ },
1025
+ "metadata": {},
1026
+ "output_type": "display_data"
1027
+ }
1028
+ ],
1029
+ "source": [
1030
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
1031
+ "from peft import PeftModel\n",
1032
+ "\n",
1033
+ "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n",
1034
+ "\n",
1035
+ "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n",
1036
+ "\n",
1037
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)"
1038
+ ]
1039
+ },
1040
+ {
1041
+ "cell_type": "markdown",
1042
+ "metadata": {
1043
+ "id": "JvwM6ym-7nnt"
1044
+ },
1045
+ "source": [
1046
+ "Let's test with one example from the test set of the dataset"
1047
+ ]
1048
+ },
1049
+ {
1050
+ "cell_type": "code",
1051
+ "execution_count": null,
1052
+ "metadata": {
1053
+ "colab": {
1054
+ "referenced_widgets": [
1055
+ "74ca3f7b365640ba883a9a236700517e"
1056
+ ]
1057
+ },
1058
+ "id": "XjpojLV-Wbdt",
1059
+ "outputId": "bcc039de-72ae-4713-a1fb-c006163999e7"
1060
+ },
1061
+ "outputs": [
1062
+ {
1063
+ "data": {
1064
+ "application/vnd.jupyter.widget-view+json": {
1065
+ "model_id": "74ca3f7b365640ba883a9a236700517e",
1066
+ "version_major": 2,
1067
+ "version_minor": 0
1068
+ },
1069
+ "text/plain": [
1070
+ "Map: 0%| | 0/1 [00:00<?, ? examples/s]"
1071
+ ]
1072
+ },
1073
+ "metadata": {},
1074
+ "output_type": "display_data"
1075
+ },
1076
+ {
1077
+ "data": {
1078
+ "text/plain": [
1079
+ "[{'content': 'A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process is enclosed strictly within <think> and </think> tags. After closing </think>, the assistant MUST provide the final answer in plain text.',\n",
1080
+ " 'role': 'system'},\n",
1081
+ " {'content': \"In 1988, a person's age was equal to the sum of the digits of their birth year. How old was this person?\",\n",
1082
+ " 'role': 'user'}]"
1083
+ ]
1084
+ },
1085
+ "execution_count": 5,
1086
+ "metadata": {},
1087
+ "output_type": "execute_result"
1088
+ }
1089
+ ],
1090
+ "source": [
1091
+ "from datasets import load_dataset\n",
1092
+ "\n",
1093
+ "dataset_name = 'AI-MO/NuminaMath-TIR'\n",
1094
+ "test_dataset = load_dataset(dataset_name, split='test[:1%]')\n",
1095
+ "test_dataset = test_dataset.map(make_conversation)\n",
1096
+ "test_dataset = test_dataset.remove_columns(['messages', 'problem'])\n",
1097
+ "test_dataset[0]['prompt']"
1098
+ ]
1099
+ },
1100
+ {
1101
+ "cell_type": "markdown",
1102
+ "metadata": {
1103
+ "id": "CxKyZwG28BYJ"
1104
+ },
1105
+ "source": [
1106
+ "Let's first check what's the output for the base model, without the adapter."
1107
+ ]
1108
+ },
1109
+ {
1110
+ "cell_type": "code",
1111
+ "execution_count": null,
1112
+ "metadata": {
1113
+ "id": "qTPJY96eWbdt",
1114
+ "outputId": "ed02acca-e856-44ec-fa20-c32efd81e018"
1115
+ },
1116
+ "outputs": [
1117
+ {
1118
+ "name": "stdout",
1119
+ "output_type": "stream",
1120
+ "text": [
1121
+ "To solve this problem, let's denote the birth year of the person as \\(Y\\) (where \\(Y\\) is a four-digit number) and their age in 1988 as \\(A\\). According to the given condition, their age in 1988 is equal to the sum of the digits of their birth year. \n",
1122
+ "\n",
1123
+ "Since we're looking at the year 1988, the person would be \\(1988 - Y\\) years old in that year. Given the condition:\n",
1124
+ "\n",
1125
+ "\\[1988 - Y = \\text{sum of the digits of } Y\\]\n",
1126
+ "\n",
1127
+ "Let's break down the possible range for \\(Y\\). Since the person's age must be less than or equal to 100 (as the sum of the digits of any four-digit number cannot exceed 36), \\(Y\\) must be between 1989 and 2088.\n",
1128
+ "\n",
1129
+ "We can systematically check each year in this range to find when the condition holds true. However, considering the constraint on age, we can narrow our search significantly. For example, if \\(Y\\) were 1990, the sum of its digits would be 18, which is not a reasonable age. We need\n"
1130
+ ]
1131
+ }
1132
+ ],
1133
+ "source": [
1134
+ "messages = test_dataset[0]['prompt']\n",
1135
+ "text = tokenizer.apply_chat_template(\n",
1136
+ " messages, add_generation_prompt=True, tokenize=False\n",
1137
+ ")\n",
1138
+ "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n",
1139
+ "\n",
1140
+ "generated_ids = base_model.generate(\n",
1141
+ " **model_inputs,\n",
1142
+ " max_new_tokens=256\n",
1143
+ ")\n",
1144
+ "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n",
1145
+ "\n",
1146
+ "# Decode and extract model response\n",
1147
+ "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n",
1148
+ "print(generated_text)"
1149
+ ]
1150
+ },
1151
+ {
1152
+ "cell_type": "markdown",
1153
+ "metadata": {
1154
+ "id": "V9eoUwQS8SIi"
1155
+ },
1156
+ "source": [
1157
+ "The base model neither produced reasoning traces nor provided a correct answer. Let's now load the fine-tuned model and check its performance."
1158
+ ]
1159
+ },
1160
+ {
1161
+ "cell_type": "code",
1162
+ "execution_count": null,
1163
+ "metadata": {
1164
+ "colab": {
1165
+ "referenced_widgets": [
1166
+ "073b351afd264bf0bf23043b37e0d8ce",
1167
+ "3dee429faf4e40b192cabebfe4bf2245"
1168
+ ]
1169
+ },
1170
+ "id": "CNannsXXWbdt",
1171
+ "outputId": "fc43a5b9-4ec6-43eb-fc34-f26e92434faf"
1172
+ },
1173
+ "outputs": [
1174
+ {
1175
+ "data": {
1176
+ "application/vnd.jupyter.widget-view+json": {
1177
+ "model_id": "073b351afd264bf0bf23043b37e0d8ce",
1178
+ "version_major": 2,
1179
+ "version_minor": 0
1180
+ },
1181
+ "text/plain": [
1182
+ "adapter_config.json: 0.00B [00:00, ?B/s]"
1183
+ ]
1184
+ },
1185
+ "metadata": {},
1186
+ "output_type": "display_data"
1187
+ },
1188
+ {
1189
+ "data": {
1190
+ "application/vnd.jupyter.widget-view+json": {
1191
+ "model_id": "3dee429faf4e40b192cabebfe4bf2245",
1192
+ "version_major": 2,
1193
+ "version_minor": 0
1194
+ },
1195
+ "text/plain": [
1196
+ "adapter_model.safetensors: 0%| | 0.00/162M [00:00<?, ?B/s]"
1197
+ ]
1198
+ },
1199
+ "metadata": {},
1200
+ "output_type": "display_data"
1201
+ }
1202
+ ],
1203
+ "source": [
1204
+ "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)"
1205
+ ]
1206
+ },
1207
+ {
1208
+ "cell_type": "code",
1209
+ "execution_count": null,
1210
+ "metadata": {
1211
+ "id": "3yOJ82F9Wbdt",
1212
+ "outputId": "f7b2d716-0ded-4ba4-9534-0481e81b4a15"
1213
+ },
1214
+ "outputs": [
1215
+ {
1216
+ "name": "stdout",
1217
+ "output_type": "stream",
1218
+ "text": [
1219
+ "<think> I need to find a birth year where the sum of its digits equals the person's age in 1988 </think>\n",
1220
+ "\n",
1221
+ "The person would have been born in 1979, since 1+9+7+9 = 26 and 26 is the age in 1988\n",
1222
+ "\n",
1223
+ "answer: 26\n"
1224
+ ]
1225
+ }
1226
+ ],
1227
+ "source": [
1228
+ "text = tokenizer.apply_chat_template(\n",
1229
+ " messages, add_generation_prompt=True, tokenize=False\n",
1230
+ ")\n",
1231
+ "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n",
1232
+ "\n",
1233
+ "generated_ids = fine_tuned_model.generate(\n",
1234
+ " **model_inputs,\n",
1235
+ " max_new_tokens=256\n",
1236
+ ")\n",
1237
+ "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n",
1238
+ "\n",
1239
+ "# Decode and extract model response\n",
1240
+ "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n",
1241
+ "print(generated_text)"
1242
+ ]
1243
+ },
1244
+ {
1245
+ "cell_type": "markdown",
1246
+ "metadata": {
1247
+ "id": "OU-xDHpEEmg9"
1248
+ },
1249
+ "source": [
1250
+ "The final answer is correct!"
1251
+ ]
1252
+ },
1253
+ {
1254
+ "cell_type": "markdown",
1255
+ "metadata": {
1256
+ "id": "XNtBOpRY8a2O"
1257
+ },
1258
+ "source": [
1259
+ "## Inference and Serving with vLLM\n",
1260
+ "\n",
1261
+ "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)."
1262
+ ]
1263
+ },
1264
+ {
1265
+ "cell_type": "markdown",
1266
+ "metadata": {
1267
+ "id": "nkhu0uY78lV3"
1268
+ },
1269
+ "source": [
1270
+ "### Push Merged Model (for LoRA or QLoRA Training)\n",
1271
+ "\n",
1272
+ "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first."
1273
+ ]
1274
+ },
1275
+ {
1276
+ "cell_type": "code",
1277
+ "execution_count": null,
1278
+ "metadata": {
1279
+ "id": "NF8ZP9Z-Wbdt",
1280
+ "outputId": "32a5ab71-1f0d-4289-ea12-66f5f75a957b"
1281
+ },
1282
+ "outputs": [
1283
+ {
1284
+ "data": {
1285
+ "text/plain": [
1286
+ "('Qwen2-7B-Instruct-GRPO-merged/tokenizer_config.json',\n",
1287
+ " 'Qwen2-7B-Instruct-GRPO-merged/special_tokens_map.json',\n",
1288
+ " 'Qwen2-7B-Instruct-GRPO-merged/chat_template.jinja',\n",
1289
+ " 'Qwen2-7B-Instruct-GRPO-merged/vocab.json',\n",
1290
+ " 'Qwen2-7B-Instruct-GRPO-merged/merges.txt',\n",
1291
+ " 'Qwen2-7B-Instruct-GRPO-merged/added_tokens.json',\n",
1292
+ " 'Qwen2-7B-Instruct-GRPO-merged/tokenizer.json')"
1293
+ ]
1294
+ },
1295
+ "execution_count": 29,
1296
+ "metadata": {},
1297
+ "output_type": "execute_result"
1298
+ }
1299
+ ],
1300
+ "source": [
1301
+ "model_merged = fine_tuned_model.merge_and_unload()\n",
1302
+ "\n",
1303
+ "save_dir = f\"{output_dir}-merged\"\n",
1304
+ "\n",
1305
+ "model_merged.save_pretrained(save_dir)\n",
1306
+ "tokenizer.save_pretrained(save_dir)"
1307
+ ]
1308
+ },
1309
+ {
1310
+ "cell_type": "code",
1311
+ "execution_count": null,
1312
+ "metadata": {
1313
+ "colab": {
1314
+ "referenced_widgets": [
1315
+ "d1a0574cc20046d5876cf31b21955f8b",
1316
+ "7cc2f0ef7ad2494cad572cd898095c00",
1317
+ "475420d92bb54dc08517ffe423b015c3",
1318
+ "a76231aeae5a49979d1e9075b0b3eefb",
1319
+ "b4f469f957134ea9b0e28532fe3caaf1",
1320
+ "637e55736da34f2c9b098222ae07244a",
1321
+ "8157e521017c450a9d2a9e41611405e9",
1322
+ "9746ae4ab0574ed186f898dba3b4b197",
1323
+ "d4b2a8805ec548ea85e0900ff5927574",
1324
+ "0668cd8597f141e89ef38129c6641c1f"
1325
+ ]
1326
+ },
1327
+ "id": "X5Zci39rWbdt",
1328
+ "outputId": "ca329f99-dc7b-470c-f5d9-39a3eabcb16d"
1329
+ },
1330
+ "outputs": [
1331
+ {
1332
+ "data": {
1333
+ "application/vnd.jupyter.widget-view+json": {
1334
+ "model_id": "d1a0574cc20046d5876cf31b21955f8b",
1335
+ "version_major": 2,
1336
+ "version_minor": 0
1337
+ },
1338
+ "text/plain": [
1339
+ "Processing Files (0 / 0) : | | 0.00B / 0.00B "
1340
+ ]
1341
+ },
1342
+ "metadata": {},
1343
+ "output_type": "display_data"
1344
+ },
1345
+ {
1346
+ "data": {
1347
+ "application/vnd.jupyter.widget-view+json": {
1348
+ "model_id": "7cc2f0ef7ad2494cad572cd898095c00",
1349
+ "version_major": 2,
1350
+ "version_minor": 0
1351
+ },
1352
+ "text/plain": [
1353
+ "New Data Upload : | | 0.00B / 0.00B "
1354
+ ]
1355
+ },
1356
+ "metadata": {},
1357
+ "output_type": "display_data"
1358
+ },
1359
+ {
1360
+ "data": {
1361
+ "application/vnd.jupyter.widget-view+json": {
1362
+ "model_id": "475420d92bb54dc08517ffe423b015c3",
1363
+ "version_major": 2,
1364
+ "version_minor": 0
1365
+ },
1366
+ "text/plain": [
1367
+ " ...0002-of-00004.safetensors: 0%| | 612kB / 4.93GB "
1368
+ ]
1369
+ },
1370
+ "metadata": {},
1371
+ "output_type": "display_data"
1372
+ },
1373
+ {
1374
+ "data": {
1375
+ "application/vnd.jupyter.widget-view+json": {
1376
+ "model_id": "a76231aeae5a49979d1e9075b0b3eefb",
1377
+ "version_major": 2,
1378
+ "version_minor": 0
1379
+ },
1380
+ "text/plain": [
1381
+ " ...0003-of-00004.safetensors: 0%| | 611kB / 4.33GB "
1382
+ ]
1383
+ },
1384
+ "metadata": {},
1385
+ "output_type": "display_data"
1386
+ },
1387
+ {
1388
+ "data": {
1389
+ "application/vnd.jupyter.widget-view+json": {
1390
+ "model_id": "b4f469f957134ea9b0e28532fe3caaf1",
1391
+ "version_major": 2,
1392
+ "version_minor": 0
1393
+ },
1394
+ "text/plain": [
1395
+ " ...0001-of-00004.safetensors: 1%|1 | 50.3MB / 4.88GB "
1396
+ ]
1397
+ },
1398
+ "metadata": {},
1399
+ "output_type": "display_data"
1400
+ },
1401
+ {
1402
+ "data": {
1403
+ "application/vnd.jupyter.widget-view+json": {
1404
+ "model_id": "637e55736da34f2c9b098222ae07244a",
1405
+ "version_major": 2,
1406
+ "version_minor": 0
1407
+ },
1408
+ "text/plain": [
1409
+ " ...0004-of-00004.safetensors: 4%|3 | 41.9MB / 1.09GB "
1410
+ ]
1411
+ },
1412
+ "metadata": {},
1413
+ "output_type": "display_data"
1414
+ },
1415
+ {
1416
+ "data": {
1417
+ "application/vnd.jupyter.widget-view+json": {
1418
+ "model_id": "8157e521017c450a9d2a9e41611405e9",
1419
+ "version_major": 2,
1420
+ "version_minor": 0
1421
+ },
1422
+ "text/plain": [
1423
+ "README.md: 0.00B [00:00, ?B/s]"
1424
+ ]
1425
+ },
1426
+ "metadata": {},
1427
+ "output_type": "display_data"
1428
+ },
1429
+ {
1430
+ "data": {
1431
+ "application/vnd.jupyter.widget-view+json": {
1432
+ "model_id": "9746ae4ab0574ed186f898dba3b4b197",
1433
+ "version_major": 2,
1434
+ "version_minor": 0
1435
+ },
1436
+ "text/plain": [
1437
+ "Processing Files (0 / 0) : | | 0.00B / 0.00B "
1438
+ ]
1439
+ },
1440
+ "metadata": {},
1441
+ "output_type": "display_data"
1442
+ },
1443
+ {
1444
+ "data": {
1445
+ "application/vnd.jupyter.widget-view+json": {
1446
+ "model_id": "d4b2a8805ec548ea85e0900ff5927574",
1447
+ "version_major": 2,
1448
+ "version_minor": 0
1449
+ },
1450
+ "text/plain": [
1451
+ "New Data Upload : | | 0.00B / 0.00B "
1452
+ ]
1453
+ },
1454
+ "metadata": {},
1455
+ "output_type": "display_data"
1456
+ },
1457
+ {
1458
+ "data": {
1459
+ "application/vnd.jupyter.widget-view+json": {
1460
+ "model_id": "0668cd8597f141e89ef38129c6641c1f",
1461
+ "version_major": 2,
1462
+ "version_minor": 0
1463
+ },
1464
+ "text/plain": [
1465
+ " ...RPO-merged/tokenizer.json: 100%|##########| 11.4MB / 11.4MB "
1466
+ ]
1467
+ },
1468
+ "metadata": {},
1469
+ "output_type": "display_data"
1470
+ },
1471
+ {
1472
+ "data": {
1473
+ "application/vnd.google.colaboratory.intrinsic+json": {
1474
+ "type": "string"
1475
+ },
1476
+ "text/plain": [
1477
+ "CommitInfo(commit_url='https://huggingface.co/sergiopaniego/Qwen2-7B-Instruct-GRPO-merged/commit/b20988444532e79a6915f0b2b6002b5acc2b53e1', commit_message='Upload tokenizer', commit_description='', oid='b20988444532e79a6915f0b2b6002b5acc2b53e1', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sergiopaniego/Qwen2-7B-Instruct-GRPO-merged', endpoint='https://huggingface.co', repo_type='model', repo_id='sergiopaniego/Qwen2-7B-Instruct-GRPO-merged'), pr_revision=None, pr_num=None)"
1478
+ ]
1479
+ },
1480
+ "execution_count": 30,
1481
+ "metadata": {},
1482
+ "output_type": "execute_result"
1483
+ }
1484
+ ],
1485
+ "source": [
1486
+ "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n",
1487
+ "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization"
1488
+ ]
1489
+ },
1490
+ {
1491
+ "cell_type": "markdown",
1492
+ "metadata": {
1493
+ "id": "DQ00Ivxi8rFu"
1494
+ },
1495
+ "source": [
1496
+ "### Performing Inference with vLLM\n",
1497
+ "\n",
1498
+ "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput."
1499
+ ]
1500
+ },
1501
+ {
1502
+ "cell_type": "code",
1503
+ "execution_count": null,
1504
+ "metadata": {
1505
+ "id": "x7L-HIn4Wbdt",
1506
+ "outputId": "afd66093-3525-4590-f834-c0b373e7bb9e"
1507
+ },
1508
+ "outputs": [
1509
+ {
1510
+ "name": "stdout",
1511
+ "output_type": "stream",
1512
+ "text": [
1513
+ "INFO 12-11 15:56:09 [utils.py:253] non-default args: {'dtype': torch.float16, 'max_model_len': 256, 'disable_log_stats': True, 'model_impl': 'transformers', 'model': 'sergiopaniego/Qwen2-7B-Instruct-GRPO-merged'}\n"
1514
+ ]
1515
+ },
1516
+ {
1517
+ "name": "stderr",
1518
+ "output_type": "stream",
1519
+ "text": [
1520
+ "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:104: UserWarning: \n",
1521
+ "Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.\n",
1522
+ "You are not authenticated with the Hugging Face Hub in this notebook.\n",
1523
+ "If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).\n",
1524
+ " warnings.warn(\n"
1525
+ ]
1526
+ },
1527
+ {
1528
+ "name": "stdout",
1529
+ "output_type": "stream",
1530
+ "text": [
1531
+ "INFO 12-11 15:56:37 [model.py:631] Resolved architecture: TransformersForCausalLM\n",
1532
+ "WARNING 12-11 15:56:37 [model.py:1971] Casting torch.bfloat16 to torch.float16.\n",
1533
+ "INFO 12-11 15:56:37 [model.py:1745] Using max model len 256\n",
1534
+ "INFO 12-11 15:56:40 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=8192.\n",
1535
+ "WARNING 12-11 15:56:43 [system_utils.py:103] We must use the `spawn` multiprocessing start method. Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/usage/troubleshooting.html#python-multiprocessing for more information. Reasons: CUDA is initialized\n",
1536
+ "INFO 12-11 15:57:36 [llm.py:352] Supported tasks: ['generate']\n"
1537
+ ]
1538
+ }
1539
+ ],
1540
+ "source": [
1541
+ "from vllm import LLM, SamplingParams\n",
1542
+ "from transformers import AutoTokenizer\n",
1543
+ "import torch\n",
1544
+ "\n",
1545
+ "llm = LLM(\n",
1546
+ " model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n",
1547
+ " model_impl=\"transformers\", # Select the transformers model implementation\n",
1548
+ " max_model_len=256, # Reduced for efficiency\n",
1549
+ " dtype=torch.float16\n",
1550
+ ")\n",
1551
+ "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization"
1552
+ ]
1553
+ },
1554
+ {
1555
+ "cell_type": "code",
1556
+ "execution_count": null,
1557
+ "metadata": {
1558
+ "colab": {
1559
+ "referenced_widgets": [
1560
+ "f0a4f4fb17bf4a698503212296467547",
1561
+ "5be7348f3f324b5b9397c9ad186fb35d"
1562
+ ]
1563
+ },
1564
+ "id": "ZTpSUqxNWbdt",
1565
+ "outputId": "6a9283bf-d3b7-4e54-c775-4502694b5c6d"
1566
+ },
1567
+ "outputs": [
1568
+ {
1569
+ "data": {
1570
+ "application/vnd.jupyter.widget-view+json": {
1571
+ "model_id": "f0a4f4fb17bf4a698503212296467547",
1572
+ "version_major": 2,
1573
+ "version_minor": 0
1574
+ },
1575
+ "text/plain": [
1576
+ "Adding requests: 0%| | 0/1 [00:00<?, ?it/s]"
1577
+ ]
1578
+ },
1579
+ "metadata": {},
1580
+ "output_type": "display_data"
1581
+ },
1582
+ {
1583
+ "data": {
1584
+ "application/vnd.jupyter.widget-view+json": {
1585
+ "model_id": "5be7348f3f324b5b9397c9ad186fb35d",
1586
+ "version_major": 2,
1587
+ "version_minor": 0
1588
+ },
1589
+ "text/plain": [
1590
+ "Processed prompts: 0%| | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]"
1591
+ ]
1592
+ },
1593
+ "metadata": {},
1594
+ "output_type": "display_data"
1595
+ },
1596
+ {
1597
+ "name": "stdout",
1598
+ "output_type": "stream",
1599
+ "text": [
1600
+ "<think> 1988 birth year implies the person was born either in 1979, 1980, 1981, etc. Looking for the one where sum of digits equals age </think>\n",
1601
+ "\n",
1602
+ "The birth year 1979 gives sum of digits 1+9+7+9 = 26\n",
1603
+ "\n",
1604
+ "The person was 26 years old in 1988.\n",
1605
+ "\n",
1606
+ "Answer: The person was 26 years old.\n"
1607
+ ]
1608
+ }
1609
+ ],
1610
+ "source": [
1611
+ "messages = test_dataset[0]['prompt']\n",
1612
+ "# Alternatively, use llm.chat()\n",
1613
+ "prompt = hf_tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n",
1614
+ "\n",
1615
+ "outputs = llm.generate(\n",
1616
+ " {\"prompt\": prompt},\n",
1617
+ " sampling_params=SamplingParams(max_tokens=256),\n",
1618
+ ")\n",
1619
+ "\n",
1620
+ "for o in outputs:\n",
1621
+ " generated_text = o.outputs[0].text\n",
1622
+ " print(generated_text)"
1623
+ ]
1624
+ }
1625
+ ],
1626
+ "metadata": {
1627
+ "accelerator": "GPU",
1628
+ "colab": {
1629
+ "gpuType": "T4",
1630
+ "provenance": []
1631
+ },
1632
+ "language_info": {
1633
+ "name": "python"
1634
+ }
1635
+ },
1636
+ "nbformat": 4,
1637
+ "nbformat_minor": 0
1638
+ }
ICL/RL/trl_source/examples/notebooks/openenv_sudoku_grpo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
ICL/RL/trl_source/examples/notebooks/openenv_wordle_grpo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
ICL/RL/trl_source/examples/notebooks/sft_ministral3_vl.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
ICL/RL/trl_source/examples/notebooks/sft_qwen_vl.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
ICL/RL/trl_source/examples/notebooks/sft_trl_lora_qlora.ipynb ADDED
@@ -0,0 +1,1140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "5oqSnSaqLWAL"
7
+ },
8
+ "source": [
9
+ "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n",
10
+ "\n",
11
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb)"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "metadata": {
17
+ "id": "d6c1x17tLWAR"
18
+ },
19
+ "source": [
20
+ "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "markdown",
25
+ "metadata": {
26
+ "id": "cQ6bxQaMLWAS"
27
+ },
28
+ "source": [
29
+ "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n",
30
+ "\n",
31
+ "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n",
32
+ "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n",
33
+ "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "metadata": {
39
+ "id": "JG3wax0uLWAU"
40
+ },
41
+ "source": [
42
+ "## Key concepts\n",
43
+ "\n",
44
+ "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n",
45
+ "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n",
46
+ "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n",
47
+ "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n",
48
+ "\n",
49
+ "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**."
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "metadata": {
55
+ "id": "0ZhyNnhiLWAV"
56
+ },
57
+ "source": [
58
+ "## Install dependencies\n",
59
+ "\n",
60
+ "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training."
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "metadata": {
67
+ "id": "FXTyVTJcLWAV"
68
+ },
69
+ "outputs": [],
70
+ "source": [
71
+ "!pip install -Uq \"trl[peft]\" trackio bitsandbytes liger-kernel"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "markdown",
76
+ "metadata": {
77
+ "id": "OqlMF6oWLWAY"
78
+ },
79
+ "source": [
80
+ "### Log in to Hugging Face"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "markdown",
85
+ "metadata": {
86
+ "id": "2blL6-1_LWAa"
87
+ },
88
+ "source": [
89
+ "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)."
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": null,
95
+ "metadata": {
96
+ "id": "6OMeJOp7LWAc"
97
+ },
98
+ "outputs": [],
99
+ "source": [
100
+ "from huggingface_hub import notebook_login\n",
101
+ "\n",
102
+ "notebook_login()"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "markdown",
107
+ "metadata": {
108
+ "id": "6HHscLIQLWAd"
109
+ },
110
+ "source": [
111
+ "## Load Dataset\n",
112
+ "\n",
113
+ "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n",
114
+ "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n",
115
+ "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n",
116
+ "\n",
117
+ "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n",
118
+ "> If you're using a model without reasoning skills, consider choosing a different dataset. Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n",
119
+ "\n",
120
+ "For efficiency, we'll load only the **training split**:"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {
127
+ "id": "dlQSKxTnLWAd"
128
+ },
129
+ "outputs": [],
130
+ "source": [
131
+ "from datasets import load_dataset\n",
132
+ "\n",
133
+ "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n",
134
+ "train_dataset = load_dataset(dataset_name, split=\"train\")"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "markdown",
139
+ "metadata": {
140
+ "id": "bRHTwwZXLWAe"
141
+ },
142
+ "source": [
143
+ "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer."
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "metadata": {
150
+ "id": "zOBq8tVdLWAe",
151
+ "outputId": "e12ab8ae-e00c-4e89-b489-dd448db8e13b"
152
+ },
153
+ "outputs": [
154
+ {
155
+ "data": {
156
+ "text/plain": [
157
+ "Dataset({\n",
158
+ " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n",
159
+ " num_rows: 1000\n",
160
+ "})"
161
+ ]
162
+ },
163
+ "execution_count": null,
164
+ "metadata": {},
165
+ "output_type": "execute_result"
166
+ }
167
+ ],
168
+ "source": [
169
+ "train_dataset"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "markdown",
174
+ "metadata": {
175
+ "id": "b13TjFs2LWAe"
176
+ },
177
+ "source": [
178
+ "Let's see a full example to understand the internal structure:"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": null,
184
+ "metadata": {
185
+ "id": "ZON5mIMNLWAf",
186
+ "outputId": "d01415eb-26cb-45ce-ad48-0388161eea28"
187
+ },
188
+ "outputs": [
189
+ {
190
+ "data": {
191
+ "text/plain": [
192
+ "{'reasoning_language': 'French',\n",
193
+ " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n",
194
+ " 'user': 'Can you show me the latest trends on Twitter right now?',\n",
195
+ " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n",
196
+ " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n",
197
+ " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n",
198
+ " 'role': 'system',\n",
199
+ " 'thinking': None},\n",
200
+ " {'content': 'Can you show me the latest trends on Twitter right now?',\n",
201
+ " 'role': 'user',\n",
202
+ " 'thinking': None},\n",
203
+ " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n",
204
+ " 'role': 'assistant',\n",
205
+ " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}"
206
+ ]
207
+ },
208
+ "execution_count": null,
209
+ "metadata": {},
210
+ "output_type": "execute_result"
211
+ }
212
+ ],
213
+ "source": [
214
+ "train_dataset[0]"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "markdown",
219
+ "metadata": {
220
+ "id": "RPQfGZjlLWAf"
221
+ },
222
+ "source": [
223
+ "\n",
224
+ "Now, let's remove the columns that are not needed, as we just discussed:"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "metadata": {
231
+ "id": "pCM6PoIzLWAf"
232
+ },
233
+ "outputs": [],
234
+ "source": [
235
+ "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "markdown",
240
+ "metadata": {
241
+ "id": "BcU6E8KnLWAf"
242
+ },
243
+ "source": [
244
+ "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n",
245
+ "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). \n",
246
+ "To adapt it, we'll merge that part into the message content using the standard `<think>...</think>` tags.\n"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "code",
251
+ "execution_count": null,
252
+ "metadata": {
253
+ "id": "XQ2xYEq3LWAf"
254
+ },
255
+ "outputs": [],
256
+ "source": [
257
+ "def merge_thinking_and_remove_key(example):\n",
258
+ " new_messages = []\n",
259
+ " for msg in example[\"messages\"]:\n",
260
+ " content = msg[\"content\"]\n",
261
+ " thinking = msg.pop(\"thinking\", None)\n",
262
+ " if thinking and isinstance(thinking, str) and thinking.strip():\n",
263
+ " content = f\"<think>\\n{thinking}\\n</think>\\n{content}\"\n",
264
+ " msg[\"content\"] = content\n",
265
+ " new_messages.append(msg)\n",
266
+ " example[\"messages\"] = new_messages\n",
267
+ " return example\n",
268
+ "\n",
269
+ "train_dataset = train_dataset.map(merge_thinking_and_remove_key)"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "markdown",
274
+ "metadata": {
275
+ "id": "ewvZeKUcLWAf"
276
+ },
277
+ "source": [
278
+ "## Load model and configure LoRA/QLoRA\n",
279
+ "\n",
280
+ "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n",
281
+ "\n",
282
+ "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**."
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": null,
288
+ "metadata": {
289
+ "id": "sAWjOn9gLWAf"
290
+ },
291
+ "outputs": [],
292
+ "source": [
293
+ "# Select one model below by uncommenting the line you want to use 👇\n",
294
+ "## Qwen\n",
295
+ "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n",
296
+ "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n",
297
+ "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n",
298
+ "\n",
299
+ "## Llama\n",
300
+ "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n",
301
+ "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n",
302
+ "\n",
303
+ "## Gemma\n",
304
+ "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n",
305
+ "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n",
306
+ "\n",
307
+ "## Granite\n",
308
+ "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM\n",
309
+ "\n",
310
+ "## LFM2\n",
311
+ "#model_id, output_dir = \"LiquidAI/LFM2-2.6B\", \"LFM2-2.6B-SFT\" # ✅ ~5.89 GB VRAM"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "markdown",
316
+ "metadata": {
317
+ "id": "BXY9Y0_dLWAf"
318
+ },
319
+ "source": [
320
+ "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically."
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "code",
325
+ "execution_count": null,
326
+ "metadata": {
327
+ "id": "oyOoWFsLLWAg"
328
+ },
329
+ "outputs": [],
330
+ "source": [
331
+ "import torch\n",
332
+ "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n",
333
+ "\n",
334
+ "model = AutoModelForCausalLM.from_pretrained(\n",
335
+ " model_id,\n",
336
+ " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n",
337
+ " dtype=torch.float16, # Change to bfloat16 if GPU has support\n",
338
+ " use_cache=True, # Whether to cache attention outputs to speed up inference\n",
339
+ " quantization_config=BitsAndBytesConfig(\n",
340
+ " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n",
341
+ " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n",
342
+ " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n",
343
+ " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n",
344
+ " )\n",
345
+ ")"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "markdown",
350
+ "metadata": {
351
+ "id": "L-_BpOdILWAg"
352
+ },
353
+ "source": [
354
+ "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning."
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": null,
360
+ "metadata": {
361
+ "id": "9EL-glV-LWAg"
362
+ },
363
+ "outputs": [],
364
+ "source": [
365
+ "from peft import LoraConfig\n",
366
+ "\n",
367
+ "# You may need to update `target_modules` depending on the architecture of your chosen model.\n",
368
+ "# For example, different LLMs might have different attention/projection layer names.\n",
369
+ "peft_config = LoraConfig(\n",
370
+ " r=32,\n",
371
+ " lora_alpha=32,\n",
372
+ " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n",
373
+ ")"
374
+ ]
375
+ },
376
+ {
377
+ "cell_type": "markdown",
378
+ "metadata": {
379
+ "id": "-i6BMpcaLWAg"
380
+ },
381
+ "source": [
382
+ "## Train model\n",
383
+ "\n",
384
+ "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)."
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "execution_count": null,
390
+ "metadata": {
391
+ "id": "-doztoyxLWAg"
392
+ },
393
+ "outputs": [],
394
+ "source": [
395
+ "from trl import SFTConfig\n",
396
+ "\n",
397
+ "training_args = SFTConfig(\n",
398
+ " # Training schedule / optimization\n",
399
+ " per_device_train_batch_size = 1, # Batch size per GPU\n",
400
+ " gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n",
401
+ " warmup_steps = 5,\n",
402
+ " # num_train_epochs = 1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n",
403
+ " max_steps = 30,\n",
404
+ " learning_rate = 2e-4, # Learning rate for the optimizer\n",
405
+ " optim = \"paged_adamw_8bit\", # Optimizer\n",
406
+ "\n",
407
+ " # Logging / reporting\n",
408
+ " logging_steps=1, # Log training metrics every N steps\n",
409
+ " report_to=\"trackio\", # Experiment tracking tool\n",
410
+ " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n",
411
+ " output_dir=output_dir, # Where to save model checkpoints and logs\n",
412
+ "\n",
413
+ " max_length=1024, # Maximum input sequence length\n",
414
+ " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n",
415
+ " activation_offloading=True, # Offload activations to CPU to reduce GPU memory usage\n",
416
+ "\n",
417
+ " # Hub integration\n",
418
+ " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n",
419
+ " # The model will be saved under your Hub account in the repository named `output_dir`\n",
420
+ "\n",
421
+ ")"
422
+ ]
423
+ },
424
+ {
425
+ "cell_type": "markdown",
426
+ "metadata": {
427
+ "id": "Gz4ggYeeLWAg"
428
+ },
429
+ "source": [
430
+ "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to maintain memory usage low but you can configure it."
431
+ ]
432
+ },
433
+ {
434
+ "cell_type": "code",
435
+ "execution_count": null,
436
+ "metadata": {
437
+ "id": "8Yx1wkv_LWAg"
438
+ },
439
+ "outputs": [],
440
+ "source": [
441
+ "from trl import SFTTrainer\n",
442
+ "\n",
443
+ "trainer = SFTTrainer(\n",
444
+ " model=model,\n",
445
+ " args=training_args,\n",
446
+ " train_dataset=train_dataset,\n",
447
+ " peft_config=peft_config\n",
448
+ ")"
449
+ ]
450
+ },
451
+ {
452
+ "cell_type": "markdown",
453
+ "metadata": {
454
+ "id": "0MsNw3uLLWAh"
455
+ },
456
+ "source": [
457
+ "Show memory stats before training"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "code",
462
+ "execution_count": null,
463
+ "metadata": {
464
+ "id": "YIuBi-ZYLWAh",
465
+ "outputId": "7f381ba0-fe90-4c6f-df0a-938a29be4e9e"
466
+ },
467
+ "outputs": [
468
+ {
469
+ "name": "stdout",
470
+ "output_type": "stream",
471
+ "text": [
472
+ "GPU = Tesla T4. Max memory = 14.741 GB.\n",
473
+ "12.074 GB of memory reserved.\n"
474
+ ]
475
+ }
476
+ ],
477
+ "source": [
478
+ "gpu_stats = torch.cuda.get_device_properties(0)\n",
479
+ "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
480
+ "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
481
+ "\n",
482
+ "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
483
+ "print(f\"{start_gpu_memory} GB of memory reserved.\")"
484
+ ]
485
+ },
486
+ {
487
+ "cell_type": "markdown",
488
+ "metadata": {
489
+ "id": "_6G6pMGeLWAh"
490
+ },
491
+ "source": [
492
+ "And train!"
493
+ ]
494
+ },
495
+ {
496
+ "cell_type": "code",
497
+ "execution_count": null,
498
+ "metadata": {
499
+ "id": "glj5UPwWLWAh",
500
+ "outputId": "b0a046c7-f76b-42a6-d870-f54470297971"
501
+ },
502
+ "outputs": [
503
+ {
504
+ "name": "stderr",
505
+ "output_type": "stream",
506
+ "text": [
507
+ "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.\n"
508
+ ]
509
+ },
510
+ {
511
+ "name": "stdout",
512
+ "output_type": "stream",
513
+ "text": [
514
+ "* Trackio project initialized: huggingface\n",
515
+ "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n",
516
+ "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n",
517
+ "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n"
518
+ ]
519
+ },
520
+ {
521
+ "data": {
522
+ "text/html": [
523
+ "<div><iframe src=\"https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\" width=\"100%\" height=\"1000px\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
524
+ ],
525
+ "text/plain": [
526
+ "<IPython.core.display.HTML object>"
527
+ ]
528
+ },
529
+ "metadata": {},
530
+ "output_type": "display_data"
531
+ },
532
+ {
533
+ "name": "stdout",
534
+ "output_type": "stream",
535
+ "text": [
536
+ "* Created new run: sergiopaniego-1761318512\n"
537
+ ]
538
+ },
539
+ {
540
+ "data": {
541
+ "text/html": [
542
+ "\n",
543
+ " <div>\n",
544
+ " \n",
545
+ " <progress value='30' max='30' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
546
+ " [30/30 1:08:22, Epoch 0/1]\n",
547
+ " </div>\n",
548
+ " <table border=\"1\" class=\"dataframe\">\n",
549
+ " <thead>\n",
550
+ " <tr style=\"text-align: left;\">\n",
551
+ " <th>Step</th>\n",
552
+ " <th>Training Loss</th>\n",
553
+ " </tr>\n",
554
+ " </thead>\n",
555
+ " <tbody>\n",
556
+ " <tr>\n",
557
+ " <td>1</td>\n",
558
+ " <td>1.136300</td>\n",
559
+ " </tr>\n",
560
+ " <tr>\n",
561
+ " <td>2</td>\n",
562
+ " <td>1.303800</td>\n",
563
+ " </tr>\n",
564
+ " <tr>\n",
565
+ " <td>3</td>\n",
566
+ " <td>1.362700</td>\n",
567
+ " </tr>\n",
568
+ " <tr>\n",
569
+ " <td>4</td>\n",
570
+ " <td>1.469700</td>\n",
571
+ " </tr>\n",
572
+ " <tr>\n",
573
+ " <td>5</td>\n",
574
+ " <td>1.204200</td>\n",
575
+ " </tr>\n",
576
+ " <tr>\n",
577
+ " <td>6</td>\n",
578
+ " <td>1.202700</td>\n",
579
+ " </tr>\n",
580
+ " <tr>\n",
581
+ " <td>7</td>\n",
582
+ " <td>1.097200</td>\n",
583
+ " </tr>\n",
584
+ " <tr>\n",
585
+ " <td>8</td>\n",
586
+ " <td>1.166800</td>\n",
587
+ " </tr>\n",
588
+ " <tr>\n",
589
+ " <td>9</td>\n",
590
+ " <td>0.916300</td>\n",
591
+ " </tr>\n",
592
+ " <tr>\n",
593
+ " <td>10</td>\n",
594
+ " <td>0.965400</td>\n",
595
+ " </tr>\n",
596
+ " <tr>\n",
597
+ " <td>11</td>\n",
598
+ " <td>1.035500</td>\n",
599
+ " </tr>\n",
600
+ " <tr>\n",
601
+ " <td>12</td>\n",
602
+ " <td>0.947200</td>\n",
603
+ " </tr>\n",
604
+ " <tr>\n",
605
+ " <td>13</td>\n",
606
+ " <td>0.992000</td>\n",
607
+ " </tr>\n",
608
+ " <tr>\n",
609
+ " <td>14</td>\n",
610
+ " <td>0.995800</td>\n",
611
+ " </tr>\n",
612
+ " <tr>\n",
613
+ " <td>15</td>\n",
614
+ " <td>1.174500</td>\n",
615
+ " </tr>\n",
616
+ " <tr>\n",
617
+ " <td>16</td>\n",
618
+ " <td>1.208800</td>\n",
619
+ " </tr>\n",
620
+ " <tr>\n",
621
+ " <td>17</td>\n",
622
+ " <td>0.815400</td>\n",
623
+ " </tr>\n",
624
+ " <tr>\n",
625
+ " <td>18</td>\n",
626
+ " <td>0.906700</td>\n",
627
+ " </tr>\n",
628
+ " <tr>\n",
629
+ " <td>19</td>\n",
630
+ " <td>0.757500</td>\n",
631
+ " </tr>\n",
632
+ " <tr>\n",
633
+ " <td>20</td>\n",
634
+ " <td>0.872900</td>\n",
635
+ " </tr>\n",
636
+ " <tr>\n",
637
+ " <td>21</td>\n",
638
+ " <td>0.920800</td>\n",
639
+ " </tr>\n",
640
+ " <tr>\n",
641
+ " <td>22</td>\n",
642
+ " <td>1.017600</td>\n",
643
+ " </tr>\n",
644
+ " <tr>\n",
645
+ " <td>23</td>\n",
646
+ " <td>0.764300</td>\n",
647
+ " </tr>\n",
648
+ " <tr>\n",
649
+ " <td>24</td>\n",
650
+ " <td>1.043100</td>\n",
651
+ " </tr>\n",
652
+ " <tr>\n",
653
+ " <td>25</td>\n",
654
+ " <td>0.956400</td>\n",
655
+ " </tr>\n",
656
+ " <tr>\n",
657
+ " <td>26</td>\n",
658
+ " <td>0.884800</td>\n",
659
+ " </tr>\n",
660
+ " <tr>\n",
661
+ " <td>27</td>\n",
662
+ " <td>1.081900</td>\n",
663
+ " </tr>\n",
664
+ " <tr>\n",
665
+ " <td>28</td>\n",
666
+ " <td>0.918200</td>\n",
667
+ " </tr>\n",
668
+ " <tr>\n",
669
+ " <td>29</td>\n",
670
+ " <td>0.961500</td>\n",
671
+ " </tr>\n",
672
+ " <tr>\n",
673
+ " <td>30</td>\n",
674
+ " <td>0.822700</td>\n",
675
+ " </tr>\n",
676
+ " </tbody>\n",
677
+ "</table><p>"
678
+ ],
679
+ "text/plain": [
680
+ "<IPython.core.display.HTML object>"
681
+ ]
682
+ },
683
+ "metadata": {},
684
+ "output_type": "display_data"
685
+ },
686
+ {
687
+ "name": "stdout",
688
+ "output_type": "stream",
689
+ "text": [
690
+ "* Run finished. Uploading logs to Trackio (please wait...)\n"
691
+ ]
692
+ }
693
+ ],
694
+ "source": [
695
+ "trainer_stats = trainer.train()"
696
+ ]
697
+ },
698
+ {
699
+ "cell_type": "markdown",
700
+ "metadata": {
701
+ "id": "aULbOL3mLWAh"
702
+ },
703
+ "source": [
704
+ "Show memory stats after training"
705
+ ]
706
+ },
707
+ {
708
+ "cell_type": "code",
709
+ "execution_count": null,
710
+ "metadata": {
711
+ "id": "qp3m9sfXLWAh",
712
+ "outputId": "597fefc7-5510-4839-ce10-981a0aca25e8"
713
+ },
714
+ "outputs": [
715
+ {
716
+ "name": "stdout",
717
+ "output_type": "stream",
718
+ "text": [
719
+ "4249.8883 seconds used for training.\n",
720
+ "70.83 minutes used for training.\n",
721
+ "Peak reserved memory = 14.041 GB.\n",
722
+ "Peak reserved memory for training = 1.967 GB.\n",
723
+ "Peak reserved memory % of max memory = 95.251 %.\n",
724
+ "Peak reserved memory for training % of max memory = 13.344 %.\n"
725
+ ]
726
+ }
727
+ ],
728
+ "source": [
729
+ "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
730
+ "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
731
+ "used_percentage = round(used_memory / max_memory * 100, 3)\n",
732
+ "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
733
+ "\n",
734
+ "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
735
+ "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
736
+ "print(f\"Peak reserved memory = {used_memory} GB.\")\n",
737
+ "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
738
+ "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
739
+ "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
740
+ ]
741
+ },
742
+ {
743
+ "cell_type": "markdown",
744
+ "metadata": {
745
+ "id": "VJOMCsMjLWAh"
746
+ },
747
+ "source": [
748
+ "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. Example outputs would look like the following:"
749
+ ]
750
+ },
751
+ {
752
+ "cell_type": "markdown",
753
+ "metadata": {
754
+ "id": "FQNUkzVqLWAi"
755
+ },
756
+ "source": [
757
+ "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)"
758
+ ]
759
+ },
760
+ {
761
+ "cell_type": "markdown",
762
+ "metadata": {
763
+ "id": "XuCiCqj6LWAj"
764
+ },
765
+ "source": [
766
+ "## Saving fine tuned model\n",
767
+ "\n",
768
+ "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account."
769
+ ]
770
+ },
771
+ {
772
+ "cell_type": "code",
773
+ "execution_count": null,
774
+ "metadata": {
775
+ "id": "kMHh7_gFLWAj"
776
+ },
777
+ "outputs": [],
778
+ "source": [
779
+ "trainer.save_model(output_dir)\n",
780
+ "trainer.push_to_hub(dataset_name=dataset_name)"
781
+ ]
782
+ },
783
+ {
784
+ "cell_type": "markdown",
785
+ "metadata": {
786
+ "id": "rbx-Bz9yLWAq"
787
+ },
788
+ "source": [
789
+ "## Load the fine-tuned model and run inference\n",
790
+ "\n",
791
+ "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation."
792
+ ]
793
+ },
794
+ {
795
+ "cell_type": "code",
796
+ "execution_count": null,
797
+ "metadata": {
798
+ "id": "c4VwuANtLWAr"
799
+ },
800
+ "outputs": [],
801
+ "source": [
802
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
803
+ "from peft import PeftModel\n",
804
+ "\n",
805
+ "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n",
806
+ "\n",
807
+ "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"float32\", device_map=\"auto\")\n",
808
+ "\n",
809
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)"
810
+ ]
811
+ },
812
+ {
813
+ "cell_type": "markdown",
814
+ "metadata": {
815
+ "id": "vG3ejWruLWAr"
816
+ },
817
+ "source": [
818
+ "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German."
819
+ ]
820
+ },
821
+ {
822
+ "cell_type": "code",
823
+ "execution_count": null,
824
+ "metadata": {
825
+ "id": "EYiDkd-aLWAr"
826
+ },
827
+ "outputs": [],
828
+ "source": [
829
+ "messages = [\n",
830
+ " {\n",
831
+ " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n",
832
+ " 'role': 'system',\n",
833
+ " },\n",
834
+ " {\n",
835
+ " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n",
836
+ " 'role': 'user',\n",
837
+ " }\n",
838
+ "]"
839
+ ]
840
+ },
841
+ {
842
+ "cell_type": "markdown",
843
+ "metadata": {
844
+ "id": "SWO8lOd7LWAr"
845
+ },
846
+ "source": [
847
+ "Let's first check what's the output for the base model, without the adapter."
848
+ ]
849
+ },
850
+ {
851
+ "cell_type": "code",
852
+ "execution_count": null,
853
+ "metadata": {
854
+ "id": "Mt4uuTcQLWAr",
855
+ "outputId": "98f07424-3506-40d1-9e33-d4e495ba171a"
856
+ },
857
+ "outputs": [
858
+ {
859
+ "name": "stdout",
860
+ "output_type": "stream",
861
+ "text": [
862
+ "<think>\n",
863
+ "Okay, the user is asking me to check their current number of followers on their Twitter account. Let me think about how to handle this.\n",
864
+ "\n",
865
+ "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n",
866
+ "\n",
867
+ "Also, privacy is a big concern here. Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n",
868
+ "\n",
869
+ "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n",
870
+ "\n",
871
+ "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. Let me make sure the response is in German as per the user's request.\n",
872
+ "</think>\n",
873
+ "\n",
874
+ "Nein.\n"
875
+ ]
876
+ }
877
+ ],
878
+ "source": [
879
+ "text = tokenizer.apply_chat_template(\n",
880
+ " messages, add_generation_prompt=True, tokenize=False\n",
881
+ ")\n",
882
+ "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n",
883
+ "\n",
884
+ "generated_ids = base_model.generate(\n",
885
+ " **model_inputs,\n",
886
+ " max_new_tokens=512\n",
887
+ ")\n",
888
+ "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n",
889
+ "\n",
890
+ "# Decode and extract model response\n",
891
+ "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n",
892
+ "print(generated_text)"
893
+ ]
894
+ },
895
+ {
896
+ "cell_type": "markdown",
897
+ "metadata": {
898
+ "id": "fj3FIx9pLWAr"
899
+ },
900
+ "source": [
901
+ "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer."
902
+ ]
903
+ },
904
+ {
905
+ "cell_type": "code",
906
+ "execution_count": null,
907
+ "metadata": {
908
+ "id": "CmRfkvacLWAs"
909
+ },
910
+ "outputs": [],
911
+ "source": [
912
+ "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)"
913
+ ]
914
+ },
915
+ {
916
+ "cell_type": "code",
917
+ "execution_count": null,
918
+ "metadata": {
919
+ "id": "5UNOw-E0LWAs",
920
+ "outputId": "19e227c1-4211-447e-a625-14e131912759"
921
+ },
922
+ "outputs": [
923
+ {
924
+ "name": "stdout",
925
+ "output_type": "stream",
926
+ "text": [
927
+ "<think>\n",
928
+ "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. Ich denke, das ist alles.\n",
929
+ "</think>\n",
930
+ "\n",
931
+ "No\n"
932
+ ]
933
+ }
934
+ ],
935
+ "source": [
936
+ "text = tokenizer.apply_chat_template(\n",
937
+ " messages, add_generation_prompt=True, tokenize=False\n",
938
+ ")\n",
939
+ "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n",
940
+ "\n",
941
+ "generated_ids = fine_tuned_model.generate(\n",
942
+ " **model_inputs,\n",
943
+ " max_new_tokens=512\n",
944
+ ")\n",
945
+ "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n",
946
+ "\n",
947
+ "# Decode and extract model response\n",
948
+ "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n",
949
+ "print(generated_text)"
950
+ ]
951
+ },
952
+ {
953
+ "cell_type": "markdown",
954
+ "metadata": {
955
+ "id": "PM3v41YzLWAs"
956
+ },
957
+ "source": [
958
+ "The model now generates its reasoning trace in German!"
959
+ ]
960
+ },
961
+ {
962
+ "cell_type": "markdown",
963
+ "metadata": {
964
+ "id": "w-9B5m__LWAs"
965
+ },
966
+ "source": [
967
+ "## Inference and Serving with vLLM\n",
968
+ "\n",
969
+ "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)."
970
+ ]
971
+ },
972
+ {
973
+ "cell_type": "code",
974
+ "execution_count": null,
975
+ "metadata": {
976
+ "id": "NNmyG47aLWAv"
977
+ },
978
+ "outputs": [],
979
+ "source": [
980
+ "!pip install -qU vllm"
981
+ ]
982
+ },
983
+ {
984
+ "cell_type": "markdown",
985
+ "metadata": {
986
+ "id": "iJ8DnsUxLWAw"
987
+ },
988
+ "source": [
989
+ "### Push Merged Model (for LoRA or QLoRA Training)\n",
990
+ "\n",
991
+ "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first."
992
+ ]
993
+ },
994
+ {
995
+ "cell_type": "code",
996
+ "execution_count": null,
997
+ "metadata": {
998
+ "id": "aPzZ_7KDLWAw"
999
+ },
1000
+ "outputs": [],
1001
+ "source": [
1002
+ "model_merged = fine_tuned_model.merge_and_unload()\n",
1003
+ "\n",
1004
+ "save_dir = f\"{output_dir}-merged\"\n",
1005
+ "\n",
1006
+ "model_merged.save_pretrained(save_dir)\n",
1007
+ "tokenizer.save_pretrained(save_dir)"
1008
+ ]
1009
+ },
1010
+ {
1011
+ "cell_type": "code",
1012
+ "execution_count": null,
1013
+ "metadata": {
1014
+ "id": "k1Cvrkn3LWAw"
1015
+ },
1016
+ "outputs": [],
1017
+ "source": [
1018
+ "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n",
1019
+ "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization"
1020
+ ]
1021
+ },
1022
+ {
1023
+ "cell_type": "markdown",
1024
+ "metadata": {
1025
+ "id": "pR69AaJ3LWAx"
1026
+ },
1027
+ "source": [
1028
+ "### Performing Inference with vLLM\n",
1029
+ "\n",
1030
+ "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput."
1031
+ ]
1032
+ },
1033
+ {
1034
+ "cell_type": "code",
1035
+ "execution_count": null,
1036
+ "metadata": {
1037
+ "id": "UX17ZoPQLWAx"
1038
+ },
1039
+ "outputs": [],
1040
+ "source": [
1041
+ "from vllm import LLM, SamplingParams\n",
1042
+ "from transformers import AutoTokenizer\n",
1043
+ "import torch\n",
1044
+ "\n",
1045
+ "llm = LLM(\n",
1046
+ " model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n",
1047
+ " model_impl=\"transformers\", # Select the transformers model implementation\n",
1048
+ " max_model_len=512, # Reduced for efficiency\n",
1049
+ " dtype=torch.float16\n",
1050
+ ")\n",
1051
+ "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization"
1052
+ ]
1053
+ },
1054
+ {
1055
+ "cell_type": "code",
1056
+ "execution_count": null,
1057
+ "metadata": {
1058
+ "id": "0C8MhsSoLWAx",
1059
+ "outputId": "22af8503-64ac-42d5-f134-1d1dc68199e9",
1060
+ "colab": {
1061
+ "referenced_widgets": [
1062
+ "196152bc32a74b9994f55f483ce85dea",
1063
+ "a72d3a3407944729b65be313a47d558f"
1064
+ ]
1065
+ }
1066
+ },
1067
+ "outputs": [
1068
+ {
1069
+ "data": {
1070
+ "application/vnd.jupyter.widget-view+json": {
1071
+ "model_id": "196152bc32a74b9994f55f483ce85dea",
1072
+ "version_major": 2,
1073
+ "version_minor": 0
1074
+ },
1075
+ "text/plain": [
1076
+ "Adding requests: 0%| | 0/1 [00:00<?, ?it/s]"
1077
+ ]
1078
+ },
1079
+ "metadata": {},
1080
+ "output_type": "display_data"
1081
+ },
1082
+ {
1083
+ "data": {
1084
+ "application/vnd.jupyter.widget-view+json": {
1085
+ "model_id": "a72d3a3407944729b65be313a47d558f",
1086
+ "version_major": 2,
1087
+ "version_minor": 0
1088
+ },
1089
+ "text/plain": [
1090
+ "Processed prompts: 0%| | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]"
1091
+ ]
1092
+ },
1093
+ "metadata": {},
1094
+ "output_type": "display_data"
1095
+ },
1096
+ {
1097
+ "name": "stdout",
1098
+ "output_type": "stream",
1099
+ "text": [
1100
+ "<think>\n",
1101
+ "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n",
1102
+ "</think>\n",
1103
+ "\n",
1104
+ "No\n"
1105
+ ]
1106
+ }
1107
+ ],
1108
+ "source": [
1109
+ "# Alternatively, use llm.chat()\n",
1110
+ "prompt = hf_tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n",
1111
+ "\n",
1112
+ "outputs = llm.generate(\n",
1113
+ " {\"prompt\": prompt},\n",
1114
+ " sampling_params=SamplingParams(max_tokens=512),\n",
1115
+ ")\n",
1116
+ "\n",
1117
+ "\n",
1118
+ "for o in outputs:\n",
1119
+ " generated_text = o.outputs[0].text\n",
1120
+ " print(generated_text)"
1121
+ ]
1122
+ }
1123
+ ],
1124
+ "metadata": {
1125
+ "colab": {
1126
+ "provenance": [],
1127
+ "gpuType": "T4"
1128
+ },
1129
+ "language_info": {
1130
+ "name": "python"
1131
+ },
1132
+ "kernelspec": {
1133
+ "name": "python3",
1134
+ "display_name": "Python 3"
1135
+ },
1136
+ "accelerator": "GPU"
1137
+ },
1138
+ "nbformat": 4,
1139
+ "nbformat_minor": 0
1140
+ }
ICL/RL/trl_source/examples/scripts/bco.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "peft",
19
+ # "einops",
20
+ # "scikit-learn",
21
+ # "joblib",
22
+ # "trackio",
23
+ # "kernels",
24
+ # ]
25
+ # ///
26
+
27
+ """
28
+ Run the BCO training script with the commands below. In general, the optimal configuration for BCO will be similar to that of KTO.
29
+
30
+ # Full training:
31
+ python examples/scripts/bco.py \
32
+ --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
33
+ --trust_remote_code \
34
+ --dataset_name trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness \
35
+ --per_device_train_batch_size 16 \
36
+ --per_device_eval_batch_size 32 \
37
+ --num_train_epochs 1 \
38
+ --learning_rate 1e-6 \
39
+ --gradient_accumulation_steps 1 \
40
+ --eval_steps 0.2 \
41
+ --save_strategy no \
42
+ --output_dir bco-aligned-model \
43
+ --logging_first_step \
44
+ --max_length 2048 \
45
+ --max_completion_length 1024 \
46
+ --no_remove_unused_columns \
47
+ --warmup_steps 0.1
48
+
49
+ # QLoRA:
50
+ python examples/scripts/bco.py \
51
+ --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
52
+ --trust_remote_code \
53
+ --dataset_name trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness \
54
+ --per_device_train_batch_size 16 \
55
+ --per_device_eval_batch_size 32 \
56
+ --num_train_epochs 1 \
57
+ --learning_rate 1e-6 \
58
+ --gradient_accumulation_steps 1 \
59
+ --eval_steps 0.2 \
60
+ --save_strategy no \
61
+ --output_dir bco-aligned-model-lora \
62
+ --logging_first_step \
63
+ --warmup_steps 0.1 \
64
+ --max_length 2048 \
65
+ --max_completion_length 1024 \
66
+ --no_remove_unused_columns \
67
+ --warmup_steps 0.1 \
68
+ --use_peft \
69
+ --load_in_4bit \
70
+ --lora_target_modules all-linear \
71
+ --lora_r 16 \
72
+ --lora_alpha 16
73
+ """
74
+
75
+ import os
76
+ from functools import partial
77
+
78
+ import torch
79
+ import torch.nn.functional as F
80
+ from accelerate import Accelerator
81
+ from datasets import load_dataset
82
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, PreTrainedModel
83
+
84
+ from trl import ModelConfig, ScriptArguments, get_peft_config
85
+ from trl.experimental.bco import BCOConfig, BCOTrainer
86
+
87
+
88
+ # Enable logging in a Hugging Face Space
89
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
90
+
91
+
92
+ def embed_prompt(input_ids: torch.LongTensor, attention_mask: torch.LongTensor, model: PreTrainedModel):
93
+ """
94
+ Borrowed from https://huggingface.co/nomic-ai/nomic-embed-text-v1.5#transformers
95
+ """
96
+
97
+ def mean_pooling(model_output, attention_mask):
98
+ token_embeddings = model_output[0]
99
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
100
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
101
+
102
+ with torch.no_grad():
103
+ model_output = model(input_ids=input_ids, attention_mask=attention_mask)
104
+ embeddings = mean_pooling(model_output, attention_mask)
105
+
106
+ matryoshka_dim = 512
107
+ # normalize embeddings
108
+ embeddings = F.normalize(embeddings, p=2, dim=1)
109
+ embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
110
+ embeddings = embeddings[:, :matryoshka_dim]
111
+
112
+ return embeddings
113
+
114
+
115
+ if __name__ == "__main__":
116
+ parser = HfArgumentParser((ScriptArguments, BCOConfig, ModelConfig))
117
+ script_args, training_args, model_args = parser.parse_args_into_dataclasses()
118
+
119
+ training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
120
+
121
+ # Load a pretrained model
122
+ model = AutoModelForCausalLM.from_pretrained(
123
+ model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
124
+ )
125
+ ref_model = AutoModelForCausalLM.from_pretrained(
126
+ model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
127
+ )
128
+
129
+ tokenizer = AutoTokenizer.from_pretrained(
130
+ model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
131
+ )
132
+ if tokenizer.pad_token is None:
133
+ tokenizer.pad_token = tokenizer.eos_token
134
+
135
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
136
+
137
+ accelerator = Accelerator()
138
+ embedding_model = AutoModel.from_pretrained(
139
+ "nomic-ai/nomic-embed-text-v1.5",
140
+ trust_remote_code=model_args.trust_remote_code,
141
+ safe_serialization=True,
142
+ dtype=torch.bfloat16,
143
+ device_map="auto",
144
+ )
145
+ embedding_model = accelerator.prepare_model(embedding_model)
146
+ embedding_tokenizer = AutoTokenizer.from_pretrained(
147
+ "bert-base-uncased", trust_remote_code=model_args.trust_remote_code
148
+ )
149
+ embedding_func = partial(
150
+ embed_prompt,
151
+ model=embedding_model,
152
+ )
153
+
154
+ # Initialize the BCO trainer
155
+ trainer = BCOTrainer(
156
+ model,
157
+ ref_model,
158
+ args=training_args,
159
+ train_dataset=dataset[script_args.dataset_train_split],
160
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
161
+ processing_class=tokenizer,
162
+ peft_config=get_peft_config(model_args),
163
+ embedding_func=embedding_func,
164
+ embedding_tokenizer=embedding_tokenizer,
165
+ )
166
+
167
+ # Train and push the model to the Hub
168
+ trainer.train()
169
+
170
+ # Save and push to hub
171
+ trainer.save_model(training_args.output_dir)
172
+ if training_args.push_to_hub:
173
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
ICL/RL/trl_source/examples/scripts/cpo.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "peft",
19
+ # "trackio",
20
+ # "kernels",
21
+ # ]
22
+ # ///
23
+
24
+ """
25
+ Run the CPO training script with the following command with some example arguments.
26
+ In general, the optimal configuration for CPO will be similar to that of DPO:
27
+
28
+ # Full training:
29
+ python examples/scripts/cpo.py \
30
+ --dataset_name trl-lib/ultrafeedback_binarized \
31
+ --model_name_or_path gpt2 \
32
+ --per_device_train_batch_size 4 \
33
+ --max_steps 1000 \
34
+ --learning_rate 8e-6 \
35
+ --gradient_accumulation_steps 1 \
36
+ --eval_steps 500 \
37
+ --output_dir "gpt2-aligned-cpo" \
38
+ --warmup_steps 150 \
39
+ --logging_first_step \
40
+ --no_remove_unused_columns
41
+
42
+ # QLoRA:
43
+ python examples/scripts/cpo.py \
44
+ --dataset_name trl-lib/ultrafeedback_binarized \
45
+ --model_name_or_path gpt2 \
46
+ --per_device_train_batch_size 4 \
47
+ --max_steps 1000 \
48
+ --learning_rate 8e-5 \
49
+ --gradient_accumulation_steps 1 \
50
+ --eval_steps 500 \
51
+ --output_dir "gpt2-lora-aligned-cpo" \
52
+ --optim rmsprop \
53
+ --warmup_steps 150 \
54
+ --logging_first_step \
55
+ --no_remove_unused_columns \
56
+ --use_peft \
57
+ --lora_r 16 \
58
+ --lora_alpha 16
59
+ """
60
+
61
+ import os
62
+
63
+ from datasets import load_dataset
64
+ from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
65
+
66
+ from trl import ModelConfig, ScriptArguments, get_peft_config
67
+ from trl.experimental.cpo import CPOConfig, CPOTrainer
68
+
69
+
70
+ # Enable logging in a Hugging Face Space
71
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
72
+
73
+ if __name__ == "__main__":
74
+ parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig))
75
+ script_args, training_args, model_args = parser.parse_args_into_dataclasses()
76
+
77
+ ################
78
+ # Model & Tokenizer
79
+ ################
80
+ model = AutoModelForCausalLM.from_pretrained(
81
+ model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
82
+ )
83
+ tokenizer = AutoTokenizer.from_pretrained(
84
+ model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
85
+ )
86
+ if tokenizer.pad_token is None:
87
+ tokenizer.pad_token = tokenizer.eos_token
88
+
89
+ ################
90
+ # Dataset
91
+ ################
92
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
93
+
94
+ ################
95
+ # Training
96
+ ################
97
+ trainer = CPOTrainer(
98
+ model,
99
+ args=training_args,
100
+ train_dataset=dataset[script_args.dataset_train_split],
101
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
102
+ processing_class=tokenizer,
103
+ peft_config=get_peft_config(model_args),
104
+ )
105
+
106
+ # train and save the model
107
+ trainer.train()
108
+
109
+ # Save and push to hub
110
+ trainer.save_model(training_args.output_dir)
111
+ if training_args.push_to_hub:
112
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
ICL/RL/trl_source/examples/scripts/dpo.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ ###############################################################################################
16
+ # This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py #
17
+ ###############################################################################################
ICL/RL/trl_source/examples/scripts/dpo_vlm.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "peft",
19
+ # "Pillow>=9.4.0",
20
+ # "torchvision",
21
+ # "trackio",
22
+ # "kernels",
23
+ # ]
24
+ # ///
25
+
26
+ """
27
+ Without dataset streaming:
28
+
29
+ ```
30
+ accelerate launch examples/scripts/dpo_vlm.py \
31
+ --dataset_name HuggingFaceH4/rlaif-v_formatted \
32
+ --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
33
+ --per_device_train_batch_size 2 \
34
+ --gradient_accumulation_steps 32 \
35
+ --dataset_num_proc 32 \
36
+ --output_dir dpo_qwen_2_5_rlaif-v \
37
+ --dtype bfloat16 \
38
+ --use_peft \
39
+ --lora_target_modules all-linear
40
+ ```
41
+
42
+ With dataset streaming:
43
+
44
+ ```
45
+ accelerate launch examples/scripts/dpo_vlm.py \
46
+ --dataset_name HuggingFaceH4/rlaif-v_formatted \
47
+ --dataset_streaming \
48
+ --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
49
+ --per_device_train_batch_size 2 \
50
+ --max_steps 100 \
51
+ --gradient_accumulation_steps 32 \
52
+ --dataset_num_proc 32 \
53
+ --output_dir dpo_qwen_2_5_rlaif-v \
54
+ --dtype bfloat16 \
55
+ --use_peft \
56
+ --lora_target_modules all-linear
57
+ ```
58
+ """
59
+
60
+ import os
61
+
62
+ import torch
63
+ from datasets import load_dataset
64
+ from transformers import AutoModelForImageTextToText, AutoProcessor
65
+
66
+ from trl import (
67
+ DPOConfig,
68
+ DPOTrainer,
69
+ ModelConfig,
70
+ ScriptArguments,
71
+ TrlParser,
72
+ get_kbit_device_map,
73
+ get_peft_config,
74
+ get_quantization_config,
75
+ )
76
+
77
+
78
+ # Enable logging in a Hugging Face Space
79
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
80
+
81
+ if __name__ == "__main__":
82
+ parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
83
+ script_args, training_args, model_args = parser.parse_args_and_config()
84
+
85
+ ################
86
+ # Model & Processor
87
+ ################
88
+ dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
89
+
90
+ model_kwargs = dict(
91
+ revision=model_args.model_revision,
92
+ attn_implementation=model_args.attn_implementation,
93
+ dtype=dtype,
94
+ )
95
+ quantization_config = get_quantization_config(model_args)
96
+ if quantization_config is not None:
97
+ # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
98
+ model_kwargs["device_map"] = get_kbit_device_map()
99
+ model_kwargs["quantization_config"] = quantization_config
100
+
101
+ model = AutoModelForImageTextToText.from_pretrained(
102
+ model_args.model_name_or_path,
103
+ trust_remote_code=model_args.trust_remote_code,
104
+ **model_kwargs,
105
+ )
106
+ peft_config = get_peft_config(model_args)
107
+ if peft_config is None:
108
+ ref_model = AutoModelForImageTextToText.from_pretrained(
109
+ model_args.model_name_or_path,
110
+ trust_remote_code=model_args.trust_remote_code,
111
+ **model_kwargs,
112
+ )
113
+ else:
114
+ ref_model = None
115
+ processor = AutoProcessor.from_pretrained(
116
+ model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, do_image_splitting=False
117
+ )
118
+
119
+ if script_args.ignore_bias_buffers:
120
+ # torch distributed hack
121
+ model._ddp_params_and_buffers_to_ignore = [
122
+ name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
123
+ ]
124
+
125
+ ################
126
+ # Dataset
127
+ ################
128
+ dataset = load_dataset(
129
+ script_args.dataset_name,
130
+ name=script_args.dataset_config,
131
+ streaming=script_args.dataset_streaming,
132
+ )
133
+
134
+ ################
135
+ # Training
136
+ ################
137
+ trainer = DPOTrainer(
138
+ model,
139
+ ref_model,
140
+ args=training_args,
141
+ train_dataset=dataset[script_args.dataset_train_split],
142
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
143
+ peft_config=peft_config,
144
+ )
145
+
146
+ trainer.train()
147
+
148
+ # Save and push to hub
149
+ trainer.save_model(training_args.output_dir)
150
+ if training_args.push_to_hub:
151
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
ICL/RL/trl_source/examples/scripts/gkd.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "peft",
19
+ # "trackio",
20
+ # "kernels",
21
+ # ]
22
+ # ///
23
+
24
+ """
25
+ # Full training:
26
+ python examples/scripts/gkd.py \
27
+ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
28
+ --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \
29
+ --dataset_name trl-lib/chatbot_arena_completions \
30
+ --learning_rate 2e-5 \
31
+ --per_device_train_batch_size 4 \
32
+ --gradient_accumulation_steps 8 \
33
+ --output_dir gkd-model \
34
+ --num_train_epochs 1 \
35
+ --push_to_hub
36
+
37
+ # LoRA:
38
+ python examples/scripts/gkd.py \
39
+ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
40
+ --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \
41
+ --dataset_name trl-lib/chatbot_arena_completions \
42
+ --learning_rate 2e-4 \
43
+ --per_device_train_batch_size 4 \
44
+ --gradient_accumulation_steps 8 \
45
+ --output_dir gkd-model \
46
+ --num_train_epochs 1 \
47
+ --push_to_hub \
48
+ --use_peft \
49
+ --lora_r 64 \
50
+ --lora_alpha 16
51
+ """
52
+
53
+ import os
54
+
55
+ from datasets import load_dataset
56
+ from transformers import AutoTokenizer, GenerationConfig
57
+
58
+ from trl import (
59
+ LogCompletionsCallback,
60
+ ModelConfig,
61
+ ScriptArguments,
62
+ TrlParser,
63
+ get_kbit_device_map,
64
+ get_peft_config,
65
+ get_quantization_config,
66
+ )
67
+ from trl.experimental.gkd import GKDConfig, GKDTrainer
68
+
69
+
70
+ # Enable logging in a Hugging Face Space
71
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
72
+
73
+
74
+ if __name__ == "__main__":
75
+ parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig))
76
+ script_args, training_args, model_args = parser.parse_args_and_config()
77
+
78
+ ################
79
+ # Model & Tokenizer
80
+ ################
81
+ model_kwargs = dict(
82
+ revision=model_args.model_revision,
83
+ trust_remote_code=model_args.trust_remote_code,
84
+ attn_implementation=model_args.attn_implementation,
85
+ dtype=model_args.dtype,
86
+ use_cache=False if training_args.gradient_checkpointing else True,
87
+ )
88
+ quantization_config = get_quantization_config(model_args)
89
+ if quantization_config is not None:
90
+ # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
91
+ model_kwargs["device_map"] = get_kbit_device_map()
92
+ model_kwargs["quantization_config"] = quantization_config
93
+
94
+ training_args.model_init_kwargs = model_kwargs
95
+
96
+ teacher_model_kwargs = dict(
97
+ revision=model_args.model_revision,
98
+ trust_remote_code=model_args.trust_remote_code,
99
+ attn_implementation=model_args.attn_implementation,
100
+ dtype=model_args.dtype,
101
+ use_cache=True,
102
+ )
103
+ if quantization_config is not None:
104
+ # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
105
+ model_kwargs["device_map"] = get_kbit_device_map()
106
+ model_kwargs["quantization_config"] = quantization_config
107
+
108
+ training_args.teacher_model_init_kwargs = teacher_model_kwargs
109
+
110
+ tokenizer = AutoTokenizer.from_pretrained(
111
+ model_args.model_name_or_path,
112
+ revision=model_args.model_revision,
113
+ trust_remote_code=model_args.trust_remote_code,
114
+ padding_side="left",
115
+ )
116
+ if tokenizer.pad_token is None:
117
+ tokenizer.pad_token = tokenizer.eos_token
118
+
119
+ ################
120
+ # Dataset
121
+ ################
122
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
123
+
124
+ ################
125
+ # Training
126
+ ################
127
+ trainer = GKDTrainer(
128
+ model=model_args.model_name_or_path,
129
+ teacher_model=training_args.teacher_model_name_or_path,
130
+ args=training_args,
131
+ train_dataset=dataset[script_args.dataset_train_split],
132
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
133
+ processing_class=tokenizer,
134
+ peft_config=get_peft_config(model_args),
135
+ )
136
+
137
+ if training_args.eval_strategy != "no":
138
+ generation_config = GenerationConfig(
139
+ max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
140
+ )
141
+ completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
142
+ trainer.add_callback(completions_callback)
143
+
144
+ trainer.train()
145
+
146
+ # Save and push to hub
147
+ trainer.save_model(training_args.output_dir)
148
+ if training_args.push_to_hub:
149
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
ICL/RL/trl_source/examples/scripts/grpo_agent.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "peft",
19
+ # "trackio",
20
+ # "kernels",
21
+ # ]
22
+ # ///
23
+
24
+ """
25
+ # Full training
26
+ ```
27
+ python examples/scripts/grpo_agent.py \
28
+ --model_name_or_path Qwen/Qwen3-1.7B \
29
+ --output_dir grpo_biogrid_qwen_3g-1.7b \
30
+ --push_to_hub True \
31
+ --use_vllm True \
32
+ --vllm_mode colocate \
33
+ --max_completion_length 1024 \
34
+ --report_to trackio \
35
+ --log_completions True \
36
+ --max_steps 400
37
+ ```
38
+ """
39
+
40
+ import os
41
+ import re
42
+ import signal
43
+ import sqlite3
44
+ import textwrap
45
+ from contextlib import contextmanager
46
+
47
+ from datasets import load_dataset
48
+
49
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser
50
+
51
+
52
+ # Enable logging in a Hugging Face Space
53
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
54
+
55
+
56
+ def query_reward(completions, answer, **kwargs):
57
+ """
58
+ Reward query strategy:
59
+ - Penalize more than 2 queries
60
+ - Penalize generic queries (LIMIT 1 / PRAGMA)
61
+ - Reward usage of WHERE
62
+ - Reward evidence supporting the final answer
63
+ """
64
+ rewards = []
65
+
66
+ for completion, ans in zip(completions, answer, strict=False):
67
+ reward = 0.0
68
+ sql_queries = []
69
+ tool_results = []
70
+
71
+ # collect all SQL queries and tool results
72
+ for turn in completion:
73
+ if turn.get("tool_calls"):
74
+ for call in turn["tool_calls"]:
75
+ sql = call["function"]["arguments"].get("sql_command", "").lower()
76
+ sql_queries.append(sql)
77
+ if turn.get("role") == "tool" and turn.get("content"):
78
+ tool_results.append(turn["content"])
79
+
80
+ # --- penalize too many queries ---
81
+ if len(sql_queries) > 3:
82
+ reward -= 1.5
83
+
84
+ # --- check query quality ---
85
+ where_count = 0
86
+ for q in sql_queries:
87
+ if "limit 1" in q:
88
+ reward -= 1.0
89
+ if " where " not in q:
90
+ reward -= 0.5
91
+ else:
92
+ where_count += 1
93
+ reward += min(where_count, 3) * 0.4 # small bonus for WHERE usage
94
+
95
+ # --- evidence check: do queries support the answer? ---
96
+ combined_results = []
97
+ error_detected = False
98
+
99
+ for res in tool_results:
100
+ if isinstance(res, dict) and "error" in res:
101
+ error_detected = True
102
+ elif isinstance(res, list):
103
+ combined_results.extend(res)
104
+
105
+ # if error detected, penalize heavily
106
+ if error_detected:
107
+ reward -= 2.0
108
+ elif len(sql_queries) == 0:
109
+ reward -= 1.5
110
+ else:
111
+ has_hits = len(combined_results) > 0
112
+ correct_answer = ans.lower()
113
+ if (has_hits and correct_answer == "yes") or (not has_hits and correct_answer == "no"):
114
+ reward += 2.0
115
+ else:
116
+ reward -= 1.5
117
+
118
+ rewards.append(reward)
119
+
120
+ return rewards
121
+
122
+
123
+ def correctness_reward(completions, answer, **kwargs):
124
+ """
125
+ Reward Yes/No correctness.
126
+ Model must provide final answer enclosed in stars — *yes* or *no*.
127
+ Does not reward informal yes/no buried in text.
128
+ """
129
+ rewards = []
130
+ for completion, ans in zip(completions, answer, strict=False):
131
+ raw = completion[-1]["content"].lower()
132
+
133
+ # detect form *yes* or *no*
134
+ match = re.search(r"\*(yes|no)\*", raw)
135
+ guess = match.group(1) if match else None
136
+
137
+ reward = 0.0
138
+
139
+ if guess is None:
140
+ reward -= 0.5 # invalid format
141
+ elif guess == ans.lower():
142
+ reward += 0.6 # correct under required format
143
+ else:
144
+ reward -= 1.0 # wrong answer
145
+
146
+ rewards.append(reward)
147
+
148
+ return rewards
149
+
150
+
151
+ def structure_reward(completions, **kwargs):
152
+ """
153
+ Reward proper assistant structure.
154
+ Encourages a logical sequence: tool call + response + optional extra content.
155
+ """
156
+ rewards = []
157
+
158
+ for completion in completions:
159
+ has_call = False
160
+ has_response = False
161
+ has_other = False
162
+
163
+ for turn in completion:
164
+ role = turn.get("role")
165
+ if role == "assistant" and turn.get("tool_calls"):
166
+ has_call = True
167
+ elif role == "tool":
168
+ has_response = True
169
+ else:
170
+ content = turn.get("content")
171
+ if content and content.strip() not in ["", "<think>"]:
172
+ has_other = True
173
+
174
+ # Reward sequences
175
+ if has_call and has_response:
176
+ if has_other:
177
+ reward = 0.1
178
+ else:
179
+ reward = 0.05 # still positive even without extra text
180
+ elif has_call and not has_response:
181
+ reward = -0.15
182
+ else:
183
+ reward = 0.0 # neutral if no call
184
+
185
+ rewards.append(reward)
186
+
187
+ return rewards
188
+
189
+
190
+ # ------------------------
191
+ # Database tool function
192
+ # ------------------------
193
+ class TimeoutError(Exception):
194
+ """Raised when a function call times out."""
195
+
196
+ pass
197
+
198
+
199
+ @contextmanager
200
+ def timeout(seconds):
201
+ """Context manager that raises TimeoutError if execution exceeds time limit."""
202
+
203
+ def timeout_handler(signum, frame):
204
+ raise TimeoutError(f"Operation timed out after {seconds} seconds")
205
+
206
+ signal.signal(signal.SIGALRM, timeout_handler)
207
+ signal.alarm(seconds)
208
+ try:
209
+ yield
210
+ finally:
211
+ signal.alarm(0)
212
+
213
+
214
+ def query_biogrid(sql_command: str) -> list[tuple]:
215
+ """
216
+ Execute a read-only SQL command on the BioGRID database.
217
+
218
+ BioGRID is a curated biological database that compiles protein, genetic, and chemical interactions from multiple organisms. It provides researchers with experimentally verified interaction data to support studies in systems biology and functional genomics.
219
+
220
+ Args:
221
+ sql_command: The SQL command to execute.
222
+
223
+ Returns:
224
+ A list of tuples containing the query results.
225
+ """
226
+ with timeout(5):
227
+ conn = sqlite3.connect("file:biogrid.db?mode=ro", uri=True)
228
+ cursor = conn.cursor()
229
+ try:
230
+ cursor.execute(sql_command)
231
+ results = cursor.fetchall()
232
+ finally:
233
+ conn.close()
234
+ return results
235
+
236
+
237
+ # ------------------------
238
+ # Dataset formatting
239
+ # ------------------------
240
+ def format_example(example):
241
+ question = example["question"]
242
+ preamble = textwrap.dedent("""\
243
+ You have access to the BioGRID SQLite database.
244
+ Use SQL queries to retrieve only the information needed to answer the question.
245
+
246
+ Genes may appear in the database in columns `Alt_IDs_Interactor_A` `Alt_IDs_Interactor_B`, `Aliases_Interactor_A` and `Aliases_Interactor_B`,
247
+ and each entry can contain multiple gene names or synonyms separated by '|', for example:
248
+ 'entrez gene/locuslink:JNKK(gene name synonym)|entrez gene/locuslink:MAPKK4(gene name synonym)|...'
249
+ So a gene like 'JNKK' or 'MAPKK4' may appear inside one of these strings.
250
+
251
+ If the database schema is unclear or you are unsure about column names:
252
+ - First inspect the schema with `PRAGMA table_info(interactions);`
253
+ - Or preview a few rows with `SELECT * FROM interactions LIMIT 1;`
254
+
255
+ Otherwise, directly query the required data.
256
+
257
+ Final answer must be enclosed in stars, e.g. *Yes* or *No*.
258
+ Facts:
259
+ - The NCBI Taxonomy identifier for humans is taxid:9606.
260
+ """)
261
+ content = f"{preamble}\nQuestion: {question}"
262
+ prompt = [{"role": "user", "content": content}]
263
+ return {"prompt": prompt}
264
+
265
+
266
+ # ------------------------
267
+ # Main
268
+ # ------------------------
269
+ if __name__ == "__main__":
270
+ parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
271
+ script_args, training_args, model_args = parser.parse_args_and_config()
272
+
273
+ # ------------------------
274
+ # Create DB
275
+ # ------------------------
276
+ print("Creating biogrid.db...")
277
+ # Load dataset
278
+ biogrid_dataset = load_dataset("qgallouedec/biogrid", split="train")
279
+ df = biogrid_dataset.to_pandas()
280
+
281
+ # Normalize column names: remove spaces, replace with underscores
282
+ df.columns = [c.replace(" ", "_") for c in df.columns]
283
+ conn = sqlite3.connect("biogrid.db")
284
+ try:
285
+ df.to_sql("interactions", conn, if_exists="replace", index=False)
286
+ print(f"biogrid.db created. Rows stored: {len(df)}")
287
+ finally:
288
+ conn.close()
289
+
290
+ # ------------------------
291
+ # Load and format dataset
292
+ # ------------------------
293
+ dataset = load_dataset("qgallouedec/biogrid_qa", split="train")
294
+ dataset = dataset.filter(
295
+ lambda example: example["question"].startswith("Does the gene ")
296
+ ) # keep only simple questions for example
297
+ dataset = dataset.map(format_example, remove_columns=["question"])
298
+
299
+ train_dataset = dataset
300
+ eval_dataset = None # No eval by default, can be added if needed
301
+
302
+ training_args.chat_template_kwargs = {"enable_thinking": False}
303
+
304
+ # ------------------------
305
+ # Initialize trainer
306
+ # ------------------------
307
+ trainer = GRPOTrainer(
308
+ model=model_args.model_name_or_path,
309
+ train_dataset=train_dataset,
310
+ eval_dataset=eval_dataset,
311
+ tools=[query_biogrid],
312
+ reward_funcs=[correctness_reward, structure_reward, query_reward],
313
+ args=training_args,
314
+ )
315
+
316
+ # ------------------------
317
+ # Train
318
+ # ------------------------
319
+ trainer.train()
320
+
321
+ # ------------------------
322
+ # Save and push
323
+ # ------------------------
324
+ trainer.save_model(training_args.output_dir)
325
+ if training_args.push_to_hub:
326
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
ICL/RL/trl_source/examples/scripts/grpo_vlm.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "Pillow",
19
+ # "peft",
20
+ # "math-verify",
21
+ # "latex2sympy2_extended",
22
+ # "torchvision",
23
+ # "trackio",
24
+ # "kernels",
25
+ # ]
26
+ # ///
27
+
28
+ """
29
+ pip install math_verify
30
+
31
+ # For Qwen/Qwen2.5-VL-3B-Instruct
32
+ accelerate launch \
33
+ --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
34
+ examples/scripts/grpo_vlm.py \
35
+ --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
36
+ --output_dir grpo-Qwen2.5-VL-3B-Instruct \
37
+ --learning_rate 1e-5 \
38
+ --dtype bfloat16 \
39
+ --max_completion_length 1024 \
40
+ --use_vllm \
41
+ --vllm_mode colocate \
42
+ --use_peft \
43
+ --lora_target_modules "q_proj", "v_proj" \
44
+ --log_completions
45
+
46
+ # For HuggingFaceTB/SmolVLM2-2.2B-Instruct
47
+ pip install num2words==0.5.14
48
+
49
+ accelerate launch \
50
+ --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
51
+ examples/scripts/grpo_vlm.py \
52
+ --model_name_or_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \
53
+ --output_dir grpo-SmolVLM2-2.2B-Instruct \
54
+ --learning_rate 1e-5 \
55
+ --dtype bfloat16 \
56
+ --max_completion_length 1024 \
57
+ --use_peft \
58
+ --lora_target_modules "q_proj", "v_proj" \
59
+ --log_completions \
60
+ --per_device_train_batch_size 1 \
61
+ --gradient_accumulation_steps 2 \
62
+ --num_generations 2
63
+
64
+ """
65
+
66
+ import os
67
+
68
+ import torch
69
+ from datasets import load_dataset
70
+
71
+ from trl import (
72
+ GRPOConfig,
73
+ GRPOTrainer,
74
+ ModelConfig,
75
+ ScriptArguments,
76
+ TrlParser,
77
+ get_kbit_device_map,
78
+ get_peft_config,
79
+ get_quantization_config,
80
+ )
81
+ from trl.rewards import accuracy_reward, think_format_reward
82
+
83
+
84
+ # Enable logging in a Hugging Face Space
85
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
86
+
87
+
88
+ if __name__ == "__main__":
89
+ parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
90
+ script_args, training_args, model_args = parser.parse_args_and_config()
91
+ ################
92
+ # Model
93
+ ################
94
+ dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
95
+ training_args.model_init_kwargs = dict(
96
+ revision=model_args.model_revision,
97
+ attn_implementation=model_args.attn_implementation,
98
+ dtype=dtype,
99
+ )
100
+ quantization_config = get_quantization_config(model_args)
101
+ if quantization_config is not None:
102
+ # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
103
+ training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
104
+ training_args.model_init_kwargs["quantization_config"] = quantization_config
105
+
106
+ ################
107
+ # Dataset
108
+ ################
109
+ dataset = load_dataset("lmms-lab/multimodal-open-r1-8k-verified", split="train")
110
+ dataset = dataset.train_test_split(test_size=100, seed=42)
111
+
112
+ SYSTEM_PROMPT = (
113
+ "A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
114
+ "assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
115
+ "The reasoning process and answer are enclosed within <think></think> tags, i.e., <think>\nThis is my "
116
+ "reasoning.\n</think>\nThis is my answer."
117
+ )
118
+
119
+ def make_conversation(example):
120
+ prompt = [
121
+ {"role": "system", "content": SYSTEM_PROMPT},
122
+ {"role": "user", "content": example["problem"]},
123
+ ]
124
+ return {"prompt": prompt}
125
+
126
+ dataset = dataset.map(make_conversation)
127
+
128
+ # Filter have big images
129
+ def filter_big_images(example):
130
+ image = example["image"]
131
+ return image.size[0] < 512 and image.size[1] < 512
132
+
133
+ dataset = dataset.filter(filter_big_images)
134
+
135
+ def convert_to_rgb(example):
136
+ image = example["image"]
137
+ if image.mode != "RGB":
138
+ image = image.convert("RGB")
139
+ example["image"] = image
140
+ return example
141
+
142
+ dataset = dataset.map(convert_to_rgb)
143
+
144
+ train_dataset = dataset["train"]
145
+ eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
146
+
147
+ ################
148
+ # Training
149
+ ################
150
+ trainer = GRPOTrainer(
151
+ model=model_args.model_name_or_path,
152
+ args=training_args,
153
+ reward_funcs=[think_format_reward, accuracy_reward],
154
+ train_dataset=train_dataset,
155
+ eval_dataset=eval_dataset,
156
+ peft_config=get_peft_config(model_args),
157
+ )
158
+
159
+ trainer.train()
160
+
161
+ # Save and push to hub
162
+ trainer.save_model(training_args.output_dir)
163
+ if training_args.push_to_hub:
164
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
ICL/RL/trl_source/examples/scripts/gspo.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "peft",
19
+ # "math-verify",
20
+ # "latex2sympy2_extended",
21
+ # "trackio",
22
+ # "kernels",
23
+ # ]
24
+ # ///
25
+
26
+ """
27
+ pip install math_verify
28
+
29
+ # For Qwen/Qwen3-0.6B
30
+ pip install num2words==0.5.14
31
+
32
+ accelerate launch \
33
+ --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
34
+ examples/scripts/gspo.py \
35
+ --model_name_or_path Qwen/Qwen3-0.6B \
36
+ --output_dir gspo-Qwen3-0.6B \
37
+ --learning_rate 1e-5 \
38
+ --dtype bfloat16 \
39
+ --max_completion_length 1024 \
40
+ --use_peft \
41
+ --lora_target_modules "q_proj", "v_proj" \
42
+ --log_completions \
43
+ --per_device_train_batch_size 8 \
44
+ --num_generations 8 \
45
+ --importance_sampling_level sequence \
46
+ --epsilon 3e-4 \
47
+ --epsilon_high 4e-4 \
48
+ --beta 0.0 \
49
+ --loss_type grpo \
50
+ --gradient_accumulation_steps 2 \
51
+ --steps_per_generation 8
52
+
53
+ """
54
+
55
+ import os
56
+
57
+ import torch
58
+ from datasets import load_dataset
59
+
60
+ from trl import (
61
+ GRPOConfig,
62
+ GRPOTrainer,
63
+ ModelConfig,
64
+ ScriptArguments,
65
+ TrlParser,
66
+ get_kbit_device_map,
67
+ get_peft_config,
68
+ get_quantization_config,
69
+ )
70
+ from trl.rewards import accuracy_reward, think_format_reward
71
+
72
+
73
+ # Enable logging in a Hugging Face Space
74
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
75
+
76
+ if __name__ == "__main__":
77
+ parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
78
+ script_args, training_args, model_args = parser.parse_args_and_config()
79
+ ################
80
+ # Model & Processor
81
+ ################
82
+ dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
83
+ training_args.model_init_kwargs = dict(
84
+ revision=model_args.model_revision,
85
+ attn_implementation=model_args.attn_implementation,
86
+ dtype=dtype,
87
+ )
88
+ quantization_config = get_quantization_config(model_args)
89
+ if quantization_config is not None:
90
+ # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
91
+ training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
92
+ training_args.model_init_kwargs["quantization_config"] = quantization_config
93
+
94
+ ################
95
+ # Dataset
96
+ ################
97
+ train_dataset, eval_dataset = load_dataset("AI-MO/NuminaMath-TIR", split=["train[:5%]", "test[:5%]"])
98
+
99
+ SYSTEM_PROMPT = (
100
+ "A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
101
+ "assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
102
+ "The reasoning process and answer are enclosed within <think></think> tags, i.e., <think>\nThis is my "
103
+ "reasoning.\n</think>\nThis is my answer."
104
+ )
105
+
106
+ def make_conversation(example):
107
+ return {
108
+ "prompt": [
109
+ {"role": "system", "content": SYSTEM_PROMPT},
110
+ {"role": "user", "content": example["problem"]},
111
+ ],
112
+ }
113
+
114
+ train_dataset = train_dataset.map(make_conversation)
115
+ eval_dataset = eval_dataset.map(make_conversation)
116
+
117
+ train_dataset = train_dataset.remove_columns(["messages", "problem"])
118
+ eval_dataset = eval_dataset.remove_columns(["messages", "problem"])
119
+
120
+ ################
121
+ # Training
122
+ ################
123
+ trainer = GRPOTrainer(
124
+ model=model_args.model_name_or_path,
125
+ args=training_args,
126
+ reward_funcs=[think_format_reward, accuracy_reward],
127
+ train_dataset=train_dataset,
128
+ eval_dataset=eval_dataset,
129
+ peft_config=get_peft_config(model_args),
130
+ )
131
+
132
+ trainer.train()
133
+
134
+ # Save and push to hub
135
+ trainer.save_model(training_args.output_dir)
136
+ if training_args.push_to_hub:
137
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
ICL/RL/trl_source/examples/scripts/gspo_vlm.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "Pillow",
19
+ # "peft",
20
+ # "math-verify",
21
+ # "latex2sympy2_extended",
22
+ # "torchvision",
23
+ # "trackio",
24
+ # "kernels",
25
+ # ]
26
+ # ///
27
+
28
+ """
29
+ pip install math_verify
30
+
31
+ # For Qwen/Qwen2.5-VL-3B-Instruct
32
+ accelerate launch \
33
+ --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
34
+ examples/scripts/gspo_vlm.py \
35
+ --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
36
+ --output_dir gspo-Qwen2.5-VL-3B-Instruct \
37
+ --learning_rate 1e-5 \
38
+ --dtype bfloat16 \
39
+ --max_completion_length 1024 \
40
+ --use_peft \
41
+ --lora_target_modules "q_proj", "v_proj" \
42
+ --log_completions \
43
+ --per_device_train_batch_size 8 \
44
+ --num_generations 8 \
45
+ --importance_sampling_level sequence \
46
+ --epsilon 3e-4 \
47
+ --epsilon_high 4e-4 \
48
+ --beta 0.0 \
49
+ --loss_type grpo \
50
+ --gradient_accumulation_steps 2 \
51
+ --steps_per_generation 8
52
+
53
+ """
54
+
55
+ import os
56
+
57
+ import torch
58
+ from datasets import load_dataset
59
+
60
+ from trl import (
61
+ GRPOConfig,
62
+ GRPOTrainer,
63
+ ModelConfig,
64
+ ScriptArguments,
65
+ TrlParser,
66
+ get_kbit_device_map,
67
+ get_peft_config,
68
+ get_quantization_config,
69
+ )
70
+ from trl.rewards import accuracy_reward, think_format_reward
71
+
72
+
73
+ # Enable logging in a Hugging Face Space
74
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
75
+
76
+
77
+ if __name__ == "__main__":
78
+ parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
79
+ script_args, training_args, model_args = parser.parse_args_and_config()
80
+ ################
81
+ # Model
82
+ ################
83
+ dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
84
+ training_args.model_init_kwargs = dict(
85
+ revision=model_args.model_revision,
86
+ attn_implementation=model_args.attn_implementation,
87
+ dtype=dtype,
88
+ )
89
+ quantization_config = get_quantization_config(model_args)
90
+ if quantization_config is not None:
91
+ # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
92
+ training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
93
+ training_args.model_init_kwargs["quantization_config"] = quantization_config
94
+
95
+ ################
96
+ # Dataset
97
+ ################
98
+ dataset = load_dataset("lmms-lab/multimodal-open-r1-8k-verified", split="train")
99
+ dataset = dataset.train_test_split(test_size=100, seed=42)
100
+
101
+ SYSTEM_PROMPT = (
102
+ "A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
103
+ "assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
104
+ "The reasoning process and answer are enclosed within <think></think> tags, i.e., <think>\nThis is my "
105
+ "reasoning.\n</think>\nThis is my answer."
106
+ )
107
+
108
+ def make_conversation(example):
109
+ prompt = [
110
+ {"role": "system", "content": SYSTEM_PROMPT},
111
+ {"role": "user", "content": example["problem"]},
112
+ ]
113
+ return {"prompt": prompt}
114
+
115
+ dataset = dataset.map(make_conversation)
116
+
117
+ # Filter have big images
118
+ def filter_big_images(example):
119
+ image = example["image"]
120
+ return image.size[0] < 512 and image.size[1] < 512
121
+
122
+ dataset = dataset.filter(filter_big_images)
123
+
124
+ def convert_to_rgb(example):
125
+ image = example["image"]
126
+ if image.mode != "RGB":
127
+ image = image.convert("RGB")
128
+ example["image"] = image
129
+ return example
130
+
131
+ dataset = dataset.map(convert_to_rgb)
132
+
133
+ train_dataset = dataset["train"]
134
+ eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
135
+
136
+ ################
137
+ # Training
138
+ ################
139
+ trainer = GRPOTrainer(
140
+ model=model_args.model_name_or_path,
141
+ args=training_args,
142
+ reward_funcs=[think_format_reward, accuracy_reward],
143
+ train_dataset=train_dataset,
144
+ eval_dataset=eval_dataset,
145
+ peft_config=get_peft_config(model_args),
146
+ )
147
+
148
+ trainer.train()
149
+
150
+ # Save and push to hub
151
+ trainer.save_model(training_args.output_dir)
152
+ if training_args.push_to_hub:
153
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
ICL/RL/trl_source/examples/scripts/kto.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "peft",
19
+ # "trackio",
20
+ # "kernels",
21
+ # ]
22
+ # ///
23
+
24
+ """
25
+ Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO.
26
+
27
+ # Full training:
28
+ python trl/scripts/kto.py \
29
+ --dataset_name trl-lib/kto-mix-14k \
30
+ --model_name_or_path trl-lib/qwen1.5-1.8b-sft \
31
+ --per_device_train_batch_size 16 \
32
+ --num_train_epochs 1 \
33
+ --learning_rate 5e-7 \
34
+ --lr_scheduler_type cosine \
35
+ --gradient_accumulation_steps 1 \
36
+ --eval_steps 500 \
37
+ --output_dir kto-aligned-model \
38
+ --warmup_steps 0.1 \
39
+ --logging_first_step
40
+
41
+ # QLoRA:
42
+ python trl/scripts/kto.py \
43
+ --dataset_name trl-lib/kto-mix-14k \
44
+ --model_name_or_path trl-lib/qwen1.5-1.8b-sft \
45
+ --per_device_train_batch_size 8 \
46
+ --num_train_epochs 1 \
47
+ --learning_rate 5e-7 \
48
+ --lr_scheduler_type cosine \
49
+ --gradient_accumulation_steps 1 \
50
+ --eval_steps 500 \
51
+ --output_dir kto-aligned-model-lora \
52
+ --warmup_steps 0.1 \
53
+ --logging_first_step \
54
+ --use_peft \
55
+ --load_in_4bit \
56
+ --lora_target_modules all-linear \
57
+ --lora_r 16 \
58
+ --lora_alpha 16
59
+ """
60
+
61
+ import os
62
+
63
+ from datasets import load_dataset
64
+ from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
65
+
66
+ from trl import ModelConfig, ScriptArguments, get_peft_config
67
+ from trl.experimental.kto import KTOConfig, KTOTrainer
68
+
69
+
70
+ # Enable logging in a Hugging Face Space
71
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
72
+
73
+
74
+ if __name__ == "__main__":
75
+ parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
76
+ script_args, training_args, model_args = parser.parse_args_into_dataclasses()
77
+
78
+ # Load a pretrained model
79
+ model = AutoModelForCausalLM.from_pretrained(
80
+ model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
81
+ )
82
+ ref_model = AutoModelForCausalLM.from_pretrained(
83
+ model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
84
+ )
85
+
86
+ tokenizer = AutoTokenizer.from_pretrained(
87
+ model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
88
+ )
89
+ if tokenizer.pad_token is None:
90
+ tokenizer.pad_token = tokenizer.eos_token
91
+
92
+ # Load the dataset
93
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
94
+
95
+ # Initialize the KTO trainer
96
+ trainer = KTOTrainer(
97
+ model,
98
+ ref_model,
99
+ args=training_args,
100
+ train_dataset=dataset[script_args.dataset_train_split],
101
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
102
+ processing_class=tokenizer,
103
+ peft_config=get_peft_config(model_args),
104
+ )
105
+
106
+ # Train and push the model to the Hub
107
+ trainer.train()
108
+
109
+ # Save and push to hub
110
+ trainer.save_model(training_args.output_dir)
111
+ if training_args.push_to_hub:
112
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
ICL/RL/trl_source/examples/scripts/mpo_vlm.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "Pillow",
19
+ # "peft",
20
+ # "torchvision",
21
+ # "trackio",
22
+ # "kernels",
23
+ # ]
24
+ # ///
25
+
26
+ """
27
+ python examples/scripts/mpo_vlm.py \
28
+ --dataset_name HuggingFaceH4/rlaif-v_formatted \
29
+ --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
30
+ --per_device_train_batch_size 4 \
31
+ --per_device_eval_batch_size 4 \
32
+ --num_train_epochs 1 \
33
+ --gradient_accumulation_steps 8 \
34
+ --dataset_num_proc 1 \
35
+ --output_dir dpo_idefics_rlaif-v \
36
+ --dtype bfloat16 \
37
+ --use_peft \
38
+ --lora_target_modules down_proj, o_proj, k_proj, q_proj, gate_proj, up_proj, v_proj \
39
+ --loss_type sigmoid bco_pair sft \
40
+ --loss_weights 0.8 0.2 1.0
41
+ """
42
+
43
+ import os
44
+
45
+ import torch
46
+ from datasets import load_dataset
47
+ from PIL import Image
48
+ from transformers import AutoModelForImageTextToText
49
+
50
+ from trl import (
51
+ DPOConfig,
52
+ DPOTrainer,
53
+ ModelConfig,
54
+ ScriptArguments,
55
+ TrlParser,
56
+ get_kbit_device_map,
57
+ get_peft_config,
58
+ get_quantization_config,
59
+ )
60
+
61
+
62
+ # Enable logging in a Hugging Face Space
63
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
64
+
65
+
66
+ if __name__ == "__main__":
67
+ parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
68
+ script_args, training_args, model_args = parser.parse_args_and_config()
69
+
70
+ ################
71
+ # Model & Processor
72
+ ################
73
+ dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
74
+
75
+ model_kwargs = dict(
76
+ trust_remote_code=model_args.trust_remote_code,
77
+ revision=model_args.model_revision,
78
+ attn_implementation=model_args.attn_implementation,
79
+ dtype=dtype,
80
+ )
81
+ quantization_config = get_quantization_config(model_args)
82
+ if quantization_config is not None:
83
+ # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
84
+ model_kwargs["device_map"] = get_kbit_device_map()
85
+ model_kwargs["quantization_config"] = quantization_config
86
+
87
+ model = AutoModelForImageTextToText.from_pretrained(
88
+ model_args.model_name_or_path,
89
+ **model_kwargs,
90
+ )
91
+ peft_config = get_peft_config(model_args)
92
+ if peft_config is None:
93
+ ref_model = AutoModelForImageTextToText.from_pretrained(
94
+ model_args.model_name_or_path,
95
+ **model_kwargs,
96
+ )
97
+ else:
98
+ ref_model = None
99
+
100
+ ################
101
+ # Dataset
102
+ ################
103
+ dataset = load_dataset(
104
+ script_args.dataset_name,
105
+ name=script_args.dataset_config,
106
+ streaming=script_args.dataset_streaming,
107
+ )
108
+ train_dataset = dataset[script_args.dataset_train_split]
109
+ test_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None
110
+
111
+ def ensure_rgb(example):
112
+ # Convert the image to RGB if it's not already
113
+ image = example["images"][0]
114
+ if isinstance(image, Image.Image):
115
+ if image.mode != "RGB":
116
+ image = image.convert("RGB")
117
+ example["images"] = [image]
118
+ return example
119
+
120
+ # Apply the transformation to the dataset (change num_proc depending on the available compute)
121
+ train_dataset = train_dataset.map(ensure_rgb, num_proc=training_args.dataset_num_proc)
122
+ if test_dataset is not None:
123
+ test_dataset = test_dataset.map(ensure_rgb, num_proc=training_args.dataset_num_proc)
124
+
125
+ ################
126
+ # Training
127
+ ################
128
+ trainer = DPOTrainer(
129
+ model=model,
130
+ ref_model=ref_model,
131
+ args=training_args,
132
+ train_dataset=train_dataset,
133
+ eval_dataset=test_dataset,
134
+ peft_config=peft_config,
135
+ )
136
+
137
+ trainer.train()
138
+
139
+ # Save and push to hub
140
+ trainer.save_model(training_args.output_dir)
141
+ if training_args.push_to_hub:
142
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
ICL/RL/trl_source/examples/scripts/nash_md.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "trackio",
19
+ # "kernels",
20
+ # ]
21
+ # ///
22
+
23
+ """
24
+ Usage:
25
+
26
+ python examples/scripts/nash_md.py \
27
+ --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
28
+ --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
29
+ --dataset_name trl-lib/tldr \
30
+ --learning_rate 5.0e-7 \
31
+ --output_dir pythia-1b-tldr-nash-md \
32
+ --per_device_train_batch_size 4 \
33
+ --gradient_accumulation_steps 32 \
34
+ --num_train_epochs 3 \
35
+ --max_new_tokens 64 \
36
+ --warmup_steps 0.1 \
37
+ --missing_eos_penalty 1.0 \
38
+ --push_to_hub
39
+
40
+
41
+ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
42
+ examples/scripts/nash_md.py \
43
+ --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
44
+ --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
45
+ --dataset_name trl-lib/tldr \
46
+ --learning_rate 5.0e-7 \
47
+ --output_dir pythia-1b-tldr-nash-md \
48
+ --per_device_train_batch_size 4 \
49
+ --gradient_accumulation_steps 32 \
50
+ --num_train_epochs 3 \
51
+ --max_new_tokens 64 \
52
+ --warmup_steps 0.1 \
53
+ --missing_eos_penalty 1.0 \
54
+ --push_to_hub
55
+ """
56
+
57
+ import os
58
+
59
+ import torch
60
+ from datasets import load_dataset
61
+ from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig
62
+
63
+ from trl import (
64
+ LogCompletionsCallback,
65
+ ModelConfig,
66
+ ScriptArguments,
67
+ TrlParser,
68
+ get_kbit_device_map,
69
+ get_quantization_config,
70
+ )
71
+ from trl.experimental.judges import HfPairwiseJudge, OpenAIPairwiseJudge, PairRMJudge
72
+ from trl.experimental.nash_md import NashMDConfig, NashMDTrainer
73
+
74
+
75
+ # Enable logging in a Hugging Face Space
76
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
77
+
78
+
79
+ JUDGES = {"pair_rm": PairRMJudge, "openai": OpenAIPairwiseJudge, "hf": HfPairwiseJudge}
80
+
81
+ if __name__ == "__main__":
82
+ parser = TrlParser((ScriptArguments, NashMDConfig, ModelConfig))
83
+ script_args, training_args, model_args = parser.parse_args_and_config()
84
+ training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
85
+
86
+ dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
87
+ model_kwargs = dict(
88
+ revision=model_args.model_revision,
89
+ attn_implementation=model_args.attn_implementation,
90
+ dtype=dtype,
91
+ use_cache=False if training_args.gradient_checkpointing else True,
92
+ )
93
+ quantization_config = get_quantization_config(model_args)
94
+ if quantization_config is not None:
95
+ # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
96
+ model_kwargs["device_map"] = get_kbit_device_map()
97
+ model_kwargs["quantization_config"] = quantization_config
98
+
99
+ model = AutoModelForCausalLM.from_pretrained(
100
+ model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
101
+ )
102
+ ref_model = AutoModelForCausalLM.from_pretrained(
103
+ model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
104
+ )
105
+
106
+ if training_args.reward_model_path is not None:
107
+ reward_model = AutoModelForSequenceClassification.from_pretrained(
108
+ training_args.reward_model_path,
109
+ num_labels=1,
110
+ trust_remote_code=model_args.trust_remote_code,
111
+ **model_kwargs,
112
+ )
113
+ else:
114
+ reward_model = None
115
+
116
+ if training_args.judge is not None:
117
+ judge_cls = JUDGES[training_args.judge]
118
+ judge = judge_cls()
119
+ else:
120
+ judge = None
121
+
122
+ tokenizer = AutoTokenizer.from_pretrained(
123
+ model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
124
+ )
125
+ if tokenizer.pad_token is None:
126
+ tokenizer.pad_token = tokenizer.eos_token
127
+
128
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
129
+
130
+ trainer = NashMDTrainer(
131
+ model=model,
132
+ ref_model=ref_model,
133
+ reward_funcs=reward_model,
134
+ judge=judge,
135
+ args=training_args,
136
+ train_dataset=dataset[script_args.dataset_train_split],
137
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
138
+ processing_class=tokenizer,
139
+ )
140
+
141
+ if training_args.eval_strategy != "no":
142
+ generation_config = GenerationConfig(
143
+ max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
144
+ )
145
+ completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
146
+ trainer.add_callback(completions_callback)
147
+
148
+ trainer.train()
149
+
150
+ # Save and push to hub
151
+ trainer.save_model(training_args.output_dir)
152
+ if training_args.push_to_hub:
153
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
ICL/RL/trl_source/examples/scripts/nemo_gym/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Post-training with NeMo Gym and TRL
2
+
3
+ This integration supports training language models in NeMo-Gym environments using TRL GRPO. Both single step and multi step tasks are supported, including multi-environment training. NeMo-Gym orchestrates rollouts, returning token ids and logprobs to TRL through the rollout function for training. Currently this integration is only supported through TRL's vllm server mode.
4
+
5
+ Check out the docs page `docs/source/nemo_gym.md` for a guide.