# Gemma 4 Assistant

## Overview

Gemma 4 Assistant is a small, text-only model that enables speculative decoding for Gemma 4 models using the
Multi-Token Prediction (MTP) method and its associated candidate generator. Pre-trained models are provided for the
instruction-tuned (IT) variants of the Gemma 4 E2B, E4B, 31B and 26B-A4B (MoE) models.

Architecturally, the Gemma 4 Assistant shares the same [`Gemma4TextModel` backbone](gemma4#transformers.Gemma4TextModel)
as other Gemma 4 models, but differs in a few key ways:

*   **The entire model uses KV sharing**. This technique, originally introduced with [Gemma 3n](./gemma3n), allows the
    model to reuse the KV cache populated by the target model it supports. The assistant can therefore skip the
    pre-fill phase entirely, which considerably reduces attention compute during the forward pass.
*   **The `position_ids` values are constant**. Since the KV cache is shared and the assistant has no means of
    updating the cache, the assistant predicts all tokens from the same position ID.
*   **Inputs are the concatenation of embeddings and hidden states**. To adapt for the static KV cache and
    `position_ids`, the model takes its inputs as the concatenation of the `embedding` and `hidden_states` for the last
    seen token from the target model and projects them into assistant model space with an `nn.Linear` transform. The
    definition of last seen token changes throughout the assisted decoding loop. For the first token drafted after
    pre-fill, the last seen token will be the last token from the prompt. For subsequent drafting steps, the last seen
    token will be the last token generated by the assistant (within a drafting round) or the last token accepted by the
    target model (between drafting rounds).
*   **Cross-attention is used to make the most of the target model's context**. Cross-attention allows the query states
    generated by the assistant to attend to the shared KV cache values from the target model, so the assistant can
    accurately predict more drafted tokens per drafting round.
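
The definition of the last seen token given above can be traced with a toy draft-and-verify loop. The sketch below is a simplified illustration in plain Python, not the Transformers implementation: `target_next`, `draft_next`, `draft_round`, `verify` and `assisted_decode` are hypothetical names, tokens are plain integers, the drafter ignores hidden states and the shared KV cache, and verification is assumed to be greedy.

```python
from typing import Callable, List, Tuple


def target_next(context: List[int]) -> int:
    """Toy target model (assumption): always predicts last token + 1, modulo 10."""
    return (context[-1] + 1) % 10


def draft_next(last_seen: int) -> int:
    """Toy assistant (assumption): agrees with the target except after tokens 3 and 7."""
    return 0 if last_seen % 4 == 3 else (last_seen + 1) % 10


def draft_round(last_seen: int, drafter: Callable[[int], int], num_draft: int) -> List[int]:
    """Within a round, the last seen token is the token the assistant just drafted."""
    drafted = []
    for _ in range(num_draft):
        last_seen = drafter(last_seen)
        drafted.append(last_seen)
    return drafted


def verify(context: List[int], drafted: List[int], target: Callable[[List[int]], int]) -> List[int]:
    """Greedy verification: keep matching drafts, replace the first mismatch, then stop."""
    accepted: List[int] = []
    for token in drafted:
        expected = target(context + accepted)
        if token != expected:
            accepted.append(expected)
            return accepted
        accepted.append(token)
    # Every draft matched, so the target contributes one extra token.
    accepted.append(target(context + accepted))
    return accepted


def assisted_decode(
    prompt: List[int],
    target: Callable[[List[int]], int],
    drafter: Callable[[int], int],
    num_draft: int = 3,
    max_new_tokens: int = 8,
) -> Tuple[List[int], List[int]]:
    tokens = list(prompt)
    last_seen_log = []
    while len(tokens) - len(prompt) < max_new_tokens:
        # First round: last prompt token. Later rounds: last token accepted by the target.
        last_seen = tokens[-1]
        last_seen_log.append(last_seen)
        drafted = draft_round(last_seen, drafter, num_draft)
        tokens.extend(verify(tokens, drafted, target))
    return tokens[: len(prompt) + max_new_tokens], last_seen_log


tokens, last_seen_log = assisted_decode([1, 2], target_next, draft_next)
print(tokens)         # [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
print(last_seen_log)  # [2, 4, 8]
```

In this toy run the first round starts from the last prompt token (`2`), and each later round starts from the last token accepted by the target (`4`, then `8`). The first round accepts only one drafted token before the target corrects the second.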

You can find all the original Gemma 4 Assistant checkpoints under the
[Gemma 4](https://huggingface.co/collections/google/gemma-4) release.

## Usage examples

The examples below demonstrate how to generate text based on an image, with the assistant model enabled for speculative decoding, using [Pipeline](/docs/transformers/main/en/main_classes/pipelines#transformers.Pipeline) or the [AutoModel](/docs/transformers/main/en/model_doc/auto#transformers.AutoModel) class.

```py
from transformers import pipeline

# Name the instance `pipe` so it does not shadow the imported `pipeline` factory.
pipe = pipeline(
    task="image-text-to-text",
    model="google/gemma-4-E2B-it",
    assistant_model="google/gemma-4-E2B-it-assistant",
)
pipe(
    images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
    text="<|image|>\n\nWhat is shown in this image?"
)
```

```py
import torch
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText, AutoProcessor

model = AutoModelForImageTextToText.from_pretrained(
    "google/gemma-4-E2B-it",
    dtype=torch.bfloat16,
    device_map="auto",
)
assistant_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-E2B-it-assistant",
    dtype=torch.bfloat16,
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(
    "google/gemma-4-E2B-it",
    padding_side="left"
)
messages = [
    {
        "role": "user", "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]

output = model.generate(**inputs, max_new_tokens=50, assistant_model=assistant_model)
print(processor.decode(output[0][input_len:], skip_special_tokens=True))
```

## Gemma4AssistantConfig[[transformers.Gemma4AssistantConfig]]

#### transformers.Gemma4AssistantConfig[[transformers.Gemma4AssistantConfig]]

[Source](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma4_assistant/configuration_gemma4_assistant.py#L29)

This is the configuration class to store the configuration of a Gemma 4 Assistant model. It is used to instantiate a
Gemma 4 Assistant model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the [google/gemma-4-e2b-it](https://huggingface.co/google/gemma-4-e2b-it).

Configuration objects inherit from [PreTrainedConfig](/docs/transformers/main/en/main_classes/configuration#transformers.PreTrainedConfig) and can be used to control the model outputs. Read the
documentation from [PreTrainedConfig](/docs/transformers/main/en/main_classes/configuration#transformers.PreTrainedConfig) for more information.

Example:

```python
>>> from transformers import (
...     Gemma4AssistantConfig,
...     Gemma4AssistantForCausalLM,
...     Gemma4TextConfig,
... )

>>> # Initializing a Gemma 4 Text config similar to google/gemma-4-e2b-it-assistant.
>>> text_config = Gemma4TextConfig(...)

>>> # Initializing a Gemma 4 Assistant config similar to google/gemma-4-e2b-it-assistant.
>>> configuration = Gemma4AssistantConfig(text_config)

>>> # Initializing a model from the google/gemma-4-e2b-it-assistant configuration.
>>> model = Gemma4AssistantForCausalLM(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```

**Parameters:**

text_config (`Union[~models.gemma4.configuration_gemma4.Gemma4TextConfig, dict[str, Any]]`, *optional*) : The config object or dictionary of the text backbone.

backbone_hidden_size (`int`, defaults to 1536) : Hidden size of the target model this assistant was trained with.

use_ordered_embeddings (`bool`, defaults to False) : If True, uses an embedding table ordered for optimal assistant model performance that needs to be re-ordered to align with the main model.

num_centroids (`int`, defaults to 2048) : The total number of centroids.

centroid_intermediate_top_k (`int`, defaults to 32) : The number of active centroids.

tie_word_embeddings (`bool`, *optional*, defaults to `True`) : Whether to tie weight embeddings according to model's `tied_weights_keys` mapping.

## Gemma4AssistantForCausalLM[[transformers.Gemma4AssistantForCausalLM]]

#### transformers.Gemma4AssistantForCausalLM[[transformers.Gemma4AssistantForCausalLM]]

[Source](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma4_assistant/modeling_gemma4_assistant.py#L110)

A model for multi-token prediction-based assisted decoding with Gemma 4.

This model inherits from [PreTrainedModel](/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel). Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
etc.).

This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.

#### create_attention_masks[[transformers.Gemma4AssistantForCausalLM.create_attention_masks]]

[Source](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma4_assistant/modeling_gemma4_assistant.py#L197)

Arguments: `inputs_embeds`, `attention_mask`, `shared_kv_states`.

Prepare the attention masks for the assisted model; the `shared_kv_states` acts as past cache in this instance.

We use bidirectional masks to account for causality:
- There is no difference for the edge case of `q_len == 1` as it acts as full attention no matter what
- SWA interprets the window as forward-looking (future) when `q_idx=1` and `kv>=1`
  - We switch from a future to a past perspective by flipping on the kv axis
  - To account for position invariant padding, we also flip the base attention mask before initial creation
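
The flipping described above can be illustrated on a single query row. The sketch below is a toy in plain Python and makes several assumptions that are not taken from the library code: `flip`, `window_row` and `assistant_swa_row` are hypothetical helpers, a mask row is a list of 0/1 values for one query, the KV cache is left-padded, and the sliding window is taken to cover the `window` keys closest to the query index.

```python
from typing import List


def flip(row: List[int]) -> List[int]:
    """Reverse a mask row along the kv axis."""
    return row[::-1]


def window_row(kv_len: int, window: int, q_idx: int = 0) -> List[int]:
    """Bidirectional sliding-window row: 1 for keys within `window` of the query index."""
    return [1 if abs(kv_idx - q_idx) < window else 0 for kv_idx in range(kv_len)]


def assistant_swa_row(padding_mask: List[int], window: int) -> List[int]:
    """Mask row for one assistant query attending to a shared KV cache (toy version)."""
    # Flip the padding mask first so it lines up with the forward-looking window.
    flipped_padding = flip(padding_mask)
    forward = window_row(len(padding_mask), window)
    combined = [p & w for p, w in zip(flipped_padding, forward)]
    # Flip back on the kv axis so the window covers the most recent keys.
    return flip(combined)


padding_mask = [0, 0, 1, 1, 1]  # two left-padding slots, three real keys

print(window_row(5, window=2))                     # [1, 1, 0, 0, 0]
print(assistant_swa_row(padding_mask, window=2))   # [0, 0, 0, 1, 1]
```

Without the flips, the single query at index 0 would attend to the two oldest slots, which here are padding. With them, it attends to the two most recent real keys.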

**Parameters:**

config ([Gemma4AssistantConfig](/docs/transformers/main/en/model_doc/gemma4_assistant#transformers.Gemma4AssistantConfig)) : Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [from_pretrained()](/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained) method to load the model weights.
#### forward[[transformers.Gemma4AssistantForCausalLM.forward]]

[Source](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma4_assistant/modeling_gemma4_assistant.py#L133)

The [Gemma4AssistantForCausalLM](/docs/transformers/main/en/model_doc/gemma4_assistant#transformers.Gemma4AssistantForCausalLM) forward method, overrides the `__call__` special method.

Although the recipe for forward pass needs to be defined within this function, one should call the `Module`
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
the latter silently ignores them.

Example:

```python
>>> from transformers import AutoTokenizer, Gemma4AssistantForCausalLM, Gemma4ForCausalLM

>>> model = Gemma4ForCausalLM.from_pretrained("google/gemma-4-e2b-it")
>>> assistant_model = Gemma4AssistantForCausalLM.from_pretrained("google/gemma-4-e2b-it-assistant")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-4-e2b-it")

>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, assistant_model=assistant_model, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]
"What is your favorite condiment?"
```

**Parameters:**

input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) : Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.  Indices can be obtained using [AutoTokenizer](/docs/transformers/main/en/model_doc/auto#transformers.AutoTokenizer). See [PreTrainedTokenizer.encode()](/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.encode) and [PreTrainedTokenizer.__call__()](/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.__call__) for details.  [What are input IDs?](../glossary#input-ids)

inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) : Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*) : Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.  [What are position IDs?](../glossary#position-ids)

attention_mask (`dict[str, torch.Tensor]` of shape `(batch_size, sequence_length)`, *optional*) : Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:  - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**.  [What are attention masks?](../glossary#attention-mask)

shared_kv_states (`dict[str, torch.Tensor]` of shape `(batch_size, 1, q_len, kv_len)`, *optional*) : A dictionary containing the computed KV values for the last layer of each `layer_type` in this model.

use_cache (`bool`, *optional*) : If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`).

**Returns:**

`tuple[torch.Tensor, torch.Tensor]`

