# General Online Logit Distillation (GOLD) Trainer

[![All_models-GOLD-blue](https://img.shields.io/badge/All_models-GOLD-blue)](https://huggingface.co/models?other=sft,gold)

## Overview

General Online Logit Distillation (GOLD) is an extension of Universal Logit Distillation (ULD) that supports
student/teacher pairs with different tokenizers. It aligns the textual spans produced by both tokenizers and merges the
associated logits so no completion tokens are dropped. This enables cross-tokenizer knowledge distillation, including
mixed model families (for example, LLaMA students with Qwen teachers).

Key capabilities:

1. **Cross-tokenizer alignment** – GOLD incrementally decodes the student and teacher tokens, groups passages with the same visible text, and merges probabilities inside each group. This guarantees loss terms are computed over the full completion even when token boundaries differ.
2. **Hybrid ULD loss** – when `uld_use_hybrid_loss` is enabled, GOLD compares exact vocabulary matches directly and falls back to the original sorted-probability ULD loss for unmatched tokens. This improves stability for students whose vocabularies only partially overlap with the teacher.
3. **Seamless integration with GKD** – GOLD inherits the on-policy vs. off-policy scheduling from the [experimental.gkd.GKDTrainer](/docs/trl/v1.3.0/en/gkd_trainer#trl.experimental.gkd.GKDTrainer), so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run.
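
These settings all live on `GOLDConfig` (documented below). As a minimal sketch with illustrative values rather than tuned recommendations, a cross-tokenizer run that combines the scheduling knobs with the hybrid ULD loss might look like:

```python
from trl.experimental.gold import GOLDConfig

config = GOLDConfig(
    output_dir="gold-model",
    lmbda=0.25,                # fraction of steps trained on on-policy student generations
    beta=0.5,                  # generalized JSD interpolation (0.0 = KL, 1.0 = inverse KL)
    use_uld_loss=True,         # needed for cross-tokenizer student/teacher pairs
    uld_use_hybrid_loss=True,  # exact-match loss where the vocabularies overlap
    teacher_tokenizer_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",
)
```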

> [!NOTE]
> GOLD is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is iterated on.

## Usage tips

The `GOLDTrainer` subclasses [SFTTrainer](/docs/trl/v1.3.0/en/sft_trainer#trl.SFTTrainer) and accepts the same datasets as other TRL trainers (lists of ChatML-style
messages). Important configuration flags on `GOLDConfig` include:

* `use_uld_loss` – toggles Universal Logit Distillation. Set this to `True` for cross-tokenizer setups.
* `teacher_tokenizer_name_or_path` – the teacher tokenizer GOLD uses to align tokens. Set it whenever `use_uld_loss=True`; if left unset, GOLD falls back to the student tokenizer, which is not recommended for cross-tokenizer distillation.
* `uld_use_hybrid_loss`, `uld_hybrid_matched_weight`, `uld_hybrid_unmatched_weight` – enable and weight the hybrid
  matched/unmatched loss.
* `beta`, `lmbda`, `seq_kd` – inherited from [experimental.gkd.GKDConfig](/docs/trl/v1.3.0/en/gkd_trainer#trl.experimental.gkd.GKDConfig), controlling the generalized JSD interpolation and on-policy
  sampling ratio.
* `num_generations`, `generation_batch_size` – control buffered rollout generation across gradient accumulation windows.
  `generation_batch_size` is the number of unique prompts per worker per optimizer step.
* `model_revision` – controls which student model revision GOLD loads for training and generation.

A minimal end-to-end example:

```python
from datasets import load_dataset
from trl.experimental.gold import GOLDConfig, GOLDTrainer

train_dataset = load_dataset(
    "HuggingFaceTB/OpenR1-Math-220k-default-verified",
    "all",
    split="train[:1024]",
)

trainer = GOLDTrainer(
    model="meta-llama/Llama-3.2-1B-Instruct",
    teacher_model="Qwen/Qwen2.5-0.5B-Instruct",
    args=GOLDConfig(
        output_dir="gold-model",
        use_uld_loss=True,
        teacher_tokenizer_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",
    ),
    train_dataset=train_dataset,
)
trainer.train()
```

For quick-start workflows you can rely on string identifiers as shown above; the trainer loads the model and tokenizer for you. Explicitly instantiating `AutoModelForCausalLM` and `AutoTokenizer`, or populating `GOLDConfig` yourself, is only needed for advanced use cases where you want fine-grained control over initialization.

A more explicit setup might look like this when you need to customize model loading, tokenizer settings, or training arguments:

```python
from datasets import load_dataset
from trl.experimental.gold import GOLDConfig, GOLDTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

student_name = "meta-llama/Llama-3.2-1B-Instruct"
teacher_name = "Qwen/Qwen2.5-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(student_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(student_name)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_name)

train_dataset = load_dataset(
    "HuggingFaceTB/Countdown-Task-GOLD",
    "verified_Qwen2.5-0.5B-Instruct",
    split="train",
)

training_args = GOLDConfig(
    output_dir="gold-model",
    per_device_train_batch_size=1,
    teacher_model_name_or_path=teacher_name,
    teacher_tokenizer_name_or_path=teacher_name,
    use_uld_loss=True,
    uld_use_hybrid_loss=True,
)

trainer = GOLDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```

> [!NOTE]
> GOLD buffers one full optimizer-window generation batch (`per_device_train_batch_size * gradient_accumulation_steps`)
> and reuses it across accumulation steps. If the final batch is undersized, GOLD warns and drops that last batch
> (`Dropping last batch due to unexpected batch size`). Set `dataloader_drop_last=True` to avoid this warning.
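
For instance, with the (assumed) settings below the buffered window holds 1 × 8 = 8 samples; `dataloader_drop_last` is a standard `TrainingArguments` field that `GOLDConfig` inherits:

```python
from trl.experimental.gold import GOLDConfig

training_args = GOLDConfig(
    output_dir="gold-model",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,  # GOLD buffers 1 * 8 = 8 samples per optimizer window
    dataloader_drop_last=True,      # drop the undersized final batch instead of warning
    use_uld_loss=True,
    teacher_tokenizer_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",
)
```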

### Expected dataset type

GOLD requires a [conversational](dataset_formats#conversational) [language modeling](dataset_formats#language_modeling) dataset, e.g.:

```python
{"messages": [{"role": "user", "content": "What color is the sky?"},
              {"role": "assistant", "content": "It is blue."}]}
```

`GOLDTrainer` keeps the raw messages so the ChatML collator can construct prompts and completions with the correct
boundaries.
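
For a quick smoke test you can build such a dataset in memory with `datasets.Dataset.from_list` instead of loading one from the Hub:

```python
from datasets import Dataset

train_dataset = Dataset.from_list([
    {"messages": [{"role": "user", "content": "What color is the sky?"},
                  {"role": "assistant", "content": "It is blue."}]},
    {"messages": [{"role": "user", "content": "Where is the sun?"},
                  {"role": "assistant", "content": "In the sky."}]},
])
```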

## How Token Merging Works

When student and teacher use different tokenizers, the same text may be split differently:

- **Student**: `"Hugging Face"` → 1 token
- **Teacher**: `"Hugging"`, `" Face"` → 2 tokens

GOLD aligns these sequences and merges the teacher's multi-token probabilities into a single distribution that can be compared with the student's single-token distribution.
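
The exact split depends on the tokenizers involved (the two-token example above is illustrative), but you can inspect the differing boundaries yourself:

```python
from transformers import AutoTokenizer

student_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
teacher_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

text = "Hugging Face"
print(student_tok.tokenize(text))  # token pieces under the student tokenizer
print(teacher_tok.tokenize(text))  # typically a different segmentation
```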

### Probability Merging

For a teacher sequence of tokens `[token₀, token₁, ..., tokenₖ]` that maps to a single student token, GOLD computes:

```
P_merged(y) = P(y | context) × P(token₁ | token₀, context) × ... × P(tokenₖ | ..., context)
```

where:
- `P(y | context)` is the marginal probability distribution over all vocabulary tokens at the first position
- `P(tokenᵢ | ..., context)` are **scalar** conditional probabilities of the actual tokens that were generated

**Key insight**: Only the conditional probabilities of the **actual continuation tokens** are extracted as scalars. The full marginal distribution at the first position is then scaled by multiplying these scalar probabilities.

This ensures:
1. **Correct joint probability** for the actual generated sequence (by the chain rule)
2. **Reasonable approximation** for counterfactual tokens (scaled by the same continuation likelihood)
3. **Unnormalized distributions** that preserve the correct relative probabilities for ULD loss computation

### Example

Given:
```
P(x₀):         ["HF": 0.6,  "is": 0.3,  "cool": 0.1]
P(x₁ | "HF"):  ["HF": 0.05, "is": 0.9,  "cool": 0.05]
```

If tokens 0 and 1 are merged, and the actual sequence was `["HF", "is"]`:
```
P_merged("HF")   = 0.6 × 0.9 = 0.54  ✓ (correct joint probability)
P_merged("is")   = 0.3 × 0.9 = 0.27
P_merged("cool") = 0.1 × 0.9 = 0.09
```

The merged distribution is unnormalized (sums to 0.81), but this is intentional and correct for ULD loss computation, which uses sorting and L1 distance.
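
The arithmetic above is easy to reproduce in tensor code. Below is a minimal sketch of the merging rule together with a sorted-L1 comparison in the spirit of the ULD loss; it is not the trainer's internal implementation, and `merge_teacher_probs` and `sorted_l1` are hypothetical helper names:

```python
import torch

def merge_teacher_probs(step_probs: torch.Tensor, token_ids: list[int]) -> torch.Tensor:
    """Merge the per-step distributions of one aligned group into a single distribution.

    step_probs: (k, vocab) probabilities for the k merged teacher positions.
    token_ids: the k tokens that were actually generated at those positions.
    """
    merged = step_probs[0].clone()  # full marginal distribution at the first position
    for step in range(1, step_probs.size(0)):
        # Scale by the scalar conditional probability of the actual continuation token.
        merged = merged * step_probs[step, token_ids[step]]
    return merged  # intentionally left unnormalized

def sorted_l1(student_probs: torch.Tensor, teacher_probs: torch.Tensor) -> torch.Tensor:
    # ULD-style comparison: sorting both distributions removes the need for a
    # shared vocabulary before taking an L1 distance.
    s = torch.sort(student_probs, descending=True).values
    t = torch.sort(teacher_probs, descending=True).values
    n = min(s.numel(), t.numel())
    return (s[:n] - t[:n]).abs().sum()

probs = torch.tensor([[0.60, 0.30, 0.10],   # P(x0) over ["HF", "is", "cool"]
                      [0.05, 0.90, 0.05]])  # P(x1 | "HF")
print(merge_teacher_probs(probs, token_ids=[0, 1]))  # tensor([0.5400, 0.2700, 0.0900])
```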

## Example script

Use [`trl/experimental/gold/gold.py`](https://github.com/huggingface/trl/blob/main/trl/experimental/gold/gold.py) to launch GOLD training from the command line. The script supports full training and LoRA via the standard `ModelConfig` flags. Because the student and teacher below use different tokenizers, the ULD flags are passed explicitly.

```bash
python trl/experimental/gold/gold.py \
    --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \
    --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \
    --teacher_tokenizer_name_or_path Qwen/Qwen2-1.5B-Instruct \
    --use_uld_loss \
    --dataset_name trl-lib/chatbot_arena_completions \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --output_dir gold-model \
    --num_train_epochs 1 \
    --push_to_hub
```

## GOLDTrainer[[trl.experimental.gold.GOLDTrainer]]

#### trl.experimental.gold.GOLDTrainer[[trl.experimental.gold.GOLDTrainer]]

[Source](https://github.com/huggingface/trl/blob/v1.3.0/trl/experimental/gold/gold_trainer.py#L739)

#### train[[trl.experimental.gold.GOLDTrainer.train]]

[Source](https://github.com/huggingface/trl/blob/v1.3.0/transformers/trainer.py#L1325)

Main training entry point.

**Parameters:**

resume_from_checkpoint (`str` or `bool`, *optional*) : If a `str`, local path to a saved checkpoint as saved by a previous instance of `Trainer`. If a `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance of `Trainer`. If present, training will resume from the model/optimizer/scheduler states loaded here.

trial (`optuna.Trial` or `dict[str, Any]`, *optional*) : The trial run or the hyperparameter dictionary for hyperparameter search.

ignore_keys_for_eval (`list[str]`, *optional*) : A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.

**Returns:**

`~trainer_utils.TrainOutput`

Object containing the global step count, training loss, and metrics.
#### generate_on_policy_outputs[[trl.experimental.gold.GOLDTrainer.generate_on_policy_outputs]]

[Source](https://github.com/huggingface/trl/blob/v1.3.0/trl/experimental/gold/gold_trainer.py#L1870)
#### save_model[[trl.experimental.gold.GOLDTrainer.save_model]]

[Source](https://github.com/huggingface/trl/blob/v1.3.0/transformers/trainer.py#L3752)

Will save the model, so you can reload it using `from_pretrained()`.

Will only save from the main process.
#### push_to_hub[[trl.experimental.gold.GOLDTrainer.push_to_hub]]

[Source](https://github.com/huggingface/trl/blob/v1.3.0/transformers/trainer.py#L3999)

Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`.

**Parameters:**

commit_message (`str`, *optional*, defaults to `"End of training"`) : Message to commit while pushing.

blocking (`bool`, *optional*, defaults to `True`) : Whether the function should return only when the `git push` has finished.

token (`str`, *optional*, defaults to `None`) : Token with write permission to overwrite Trainer's original args.

revision (`str`, *optional*) : The git revision to commit from. Defaults to the head of the "main" branch.

kwargs (`dict[str, Any]`, *optional*) : Additional keyword arguments passed along to `~Trainer.create_model_card`.

**Returns:**

The URL of the repository where the model was pushed if `blocking=True`, or a `Future` object tracking the
progress of the commit if `blocking=False`.

## GOLDConfig[[trl.experimental.gold.GOLDConfig]]

#### trl.experimental.gold.GOLDConfig[[trl.experimental.gold.GOLDConfig]]

[Source](https://github.com/huggingface/trl/blob/v1.3.0/trl/experimental/gold/gold_config.py#L23)

Configuration class for `GOLDTrainer`.

This class includes only the parameters that are specific to GOLD training. For a full list of training arguments,
please refer to the [TrainingArguments](https://huggingface.co/docs/transformers/v5.6.2/en/main_classes/trainer#transformers.TrainingArguments) and [SFTConfig](/docs/trl/v1.3.0/en/sft_trainer#trl.SFTConfig) documentation.

**Parameters:**

temperature (`float`, *optional*, defaults to `0.9`) : Temperature for sampling. The higher the temperature, the more random the completions.

lmbda (`float`, *optional*, defaults to `0.5`) : Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy student-generated outputs).

beta (`float`, *optional*, defaults to `0.5`) : Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.

max_completion_length (`int`, *optional*, defaults to `128`) : Maximum number of tokens to generate per completion.

teacher_model_name_or_path (`str`, *optional*) : Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being trained.

teacher_model_revision (`str` or `None`, *optional*, defaults to `None`) : Model revision of the teacher model (e.g., branch name, tag, or commit hash). If `None`, the default revision is used.

teacher_model_init_kwargs (`dict[str, Any]`, *optional*) : Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model from a string.

teacher_tokenizer_name_or_path (`str`, *optional*) : Tokenizer name or path for the teacher model. If `None` when using ULD loss, the student tokenizer is used instead (not recommended for cross-tokenizer distillation).

disable_dropout (`bool`, *optional*, defaults to `True`) : Whether to disable dropout in the model.

seq_kd (`bool`, *optional*, defaults to `False`) : Whether to perform Sequence-Level KD (can be viewed as supervised fine-tuning on teacher-generated outputs).

num_generations (`int`, *optional*, defaults to `1`) : Number of generations per prompt. Each prompt is repeated this many times in the generation batch.

generation_batch_size (`int` or `None`, *optional*, defaults to `None`) : Number of unique prompts per worker per optimizer step. If `None`, it is computed from `(per_device_train_batch_size * gradient_accumulation_steps) // num_generations`.

use_uld_loss (`bool`, *optional*, defaults to `False`) : Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence loss.

uld_crossentropy_weight (`float`, *optional*, defaults to `0.0`) : Weight for the cross-entropy loss component in ULD loss. If 0, only ULD distillation loss is used.

uld_distillation_weight (`float`, *optional*, defaults to `1.0`) : Weight for the distillation loss component in ULD loss.

uld_student_temperature (`float`, *optional*, defaults to `1.0`) : Temperature for student logits in ULD loss computation.

uld_teacher_temperature (`float`, *optional*, defaults to `1.0`) : Temperature for teacher logits in ULD loss computation.

uld_skip_student_eos (`bool`, *optional*, defaults to `True`) : Whether to skip EOS token for student in ULD loss computation.

uld_skip_teacher_eos (`bool`, *optional*, defaults to `True`) : Whether to skip EOS token for teacher in ULD loss computation.

use_vllm (`bool`, *optional*, defaults to `False`) : Whether to use vLLM for generating completions from the student model. Requires `vllm` to be installed.

vllm_mode (`str`, *optional*, defaults to `"colocate"`) : Mode for student vLLM integration. Either `"server"` (connect to a running TRL vLLM server) or `"colocate"` (run vLLM in the same process).

vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`) : Host of the vLLM server for the student model (if `vllm_mode="server"`).

vllm_server_port (`int`, *optional*, defaults to `8001`) : Port of the vLLM server for the student model (if `vllm_mode="server"`).

vllm_server_timeout (`float`, *optional*, defaults to `240.0`) : Timeout for connecting to the student vLLM server (if `vllm_mode="server"`).

vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`) : GPU memory utilization for the colocated student vLLM engine (if `vllm_mode="colocate"`). It is recommended to set this to a low value if the student and teacher models share the same GPU.

vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`) : Tensor parallel size for the colocated student vLLM engine (if `vllm_mode="colocate"`).

vllm_structured_outputs_regex (`str`, *optional*) : Regex for vLLM structured outputs for the student model.

vllm_server_base_url (`str`, *optional*) : Base URL for the vLLM server (e.g., `"http://localhost:8001"`). If provided, `vllm_server_host` and `vllm_server_port` are ignored.

vllm_group_port (`int`, *optional*, defaults to `51216`) : Port for the vLLM weight-update group (NCCL communicator). Unless the port is occupied, there is no need to change it.

vllm_max_model_length (`int`, *optional*) : Maximum model sequence length for the colocated vLLM engine when `vllm_mode="colocate"`. Defaults to the model's maximum context length.

vllm_model_impl (`str`, *optional*, defaults to `"vllm"`) : Model implementation backend to use in vLLM. Use `"vllm"` (default) or `"transformers"`.

vllm_sync_frequency (`int`, *optional*, defaults to `1`) : Frequency (in training steps) to synchronize student model weights to vLLM engine. Set to 1 to sync after every step.

vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`) : Enable vLLM sleep mode to offload student weights/cache during the optimizer step. Keeps GPU memory usage low, but waking the engine adds host–device transfer latency.

