# gemma-4-e2b-it-sequence-classification / modeling_gemma4_sequence.py
# Last commit by baseten-admin: "Update modeling_gemma4_sequence.py" (2bf0822, verified)
"""Gemma 4 sequence classifier backed by selected next-token logits.
This module is intentionally small: it reuses the Gemma 4 multimodal backbone and
replaces the LM head with a classifier head containing selected token rows.
"""
from __future__ import annotations
from collections.abc import Sequence
import torch
from torch import nn
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.models.gemma4.configuration_gemma4 import Gemma4Config
from transformers.models.gemma4.modeling_gemma4 import Gemma4Model, Gemma4PreTrainedModel
class Gemma4ForSequenceClassification(Gemma4PreTrainedModel):
    """Pool the last text position and score it with selected Gemma 4 token rows.

    The backbone is a ``Gemma4Model``; the head ``score`` is a bias-free linear
    layer with one output per label. When built through
    :meth:`from_conditional_generation`, row ``i`` of the head is a copy of the
    LM-head row of the token chosen for label ``i``, and the config's
    ``final_logit_softcapping`` is applied to the head's outputs.
    """

    config_class = Gemma4Config
    base_model_prefix = "model"

    @classmethod
    def _can_set_experts_implementation(cls) -> bool:
        # Opt in to transformers' experts-implementation switching; presumably
        # consulted by the base class before swapping MoE kernels — confirm.
        return True

    def __init__(
        self,
        config: Gemma4Config,
        source_model: nn.Module | None = None,
        classifier_weight: torch.Tensor | None = None,
    ) -> None:
        """Build the classifier.

        Args:
            config: Gemma 4 config; ``config.num_labels`` sets the head width and
                ``config.text_config.hidden_size`` its input width.
            source_model: Optional model whose ``.model`` backbone is reused (the
                module object is shared, not copied). When ``None`` a fresh
                ``Gemma4Model(config)`` is built.
            classifier_weight: Optional tensor copied into ``score.weight``; the
                head is moved to its device and dtype first.

        Raises:
            ValueError: if ``classifier_weight`` does not have the exact shape of
                ``score.weight`` (``copy_`` would otherwise silently broadcast).
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = source_model.model if source_model is not None else Gemma4Model(config)
        self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
        if classifier_weight is not None:
            expected_shape = tuple(self.score.weight.shape)
            if tuple(classifier_weight.shape) != expected_shape:
                raise ValueError(
                    f"classifier_weight must have shape {expected_shape}, "
                    f"got {tuple(classifier_weight.shape)}."
                )
            self.score.to(device=classifier_weight.device, dtype=classifier_weight.dtype)
            with torch.no_grad():
                self.score.weight.copy_(classifier_weight)
        # Only a from-scratch build runs transformers' post-init; presumably this
        # avoids re-initialising a reused backbone or the copied head rows.
        if source_model is None and classifier_weight is None:
            self.post_init()

    @staticmethod
    def _validate_label_tokens(
        selected_token_ids: Sequence[int],
        labels: Sequence[str],
        vocab_size: int | None = None,
    ) -> None:
        """Raise ``ValueError`` unless labels and token ids pair up one-to-one.

        Args:
            selected_token_ids: One token id per label.
            labels: Label names; must be non-empty and unique.
            vocab_size: When given, every token id must lie in ``[0, vocab_size)``.
        """
        if len(labels) == 0:
            raise ValueError("Expected at least one label.")
        if len(selected_token_ids) != len(labels):
            raise ValueError(
                f"Got {len(selected_token_ids)} selected token ids for {len(labels)} "
                "labels; they must pair up one-to-one."
            )
        if len(set(labels)) != len(labels):
            raise ValueError(f"Labels must be unique, got {list(labels)}.")
        if vocab_size is not None:
            out_of_range = [
                int(token_id)
                for token_id in selected_token_ids
                if not 0 <= int(token_id) < vocab_size
            ]
            if out_of_range:
                raise ValueError(
                    f"Token ids {out_of_range} are outside the vocabulary of size {vocab_size}."
                )

    @classmethod
    def from_conditional_generation(
        cls,
        model_lm: nn.Module,
        selected_token_ids: Sequence[int],
        labels: Sequence[str],
    ) -> "Gemma4ForSequenceClassification":
        """Build a classifier that shares ``model_lm``'s backbone.

        Row ``i`` of the head is a detached copy of
        ``model_lm.lm_head.weight[selected_token_ids[i]]`` and corresponds to
        ``labels[i]``. ``model_lm.config`` is updated in place by
        :meth:`configure_classification_config` and shared with the new model.

        Raises:
            ValueError: if the labels are empty or not unique, if the number of
                token ids differs from the number of labels, or if a token id is
                outside the LM-head vocabulary.
        """
        lm_weight = model_lm.lm_head.weight
        # Validate before touching the config so a bad call leaves it unchanged.
        cls._validate_label_tokens(selected_token_ids, labels, vocab_size=lm_weight.shape[0])
        token_ids = torch.tensor(
            [int(token_id) for token_id in selected_token_ids],
            dtype=torch.long,
            device=lm_weight.device,
        )
        classifier_weight = lm_weight.index_select(0, token_ids).detach().clone()
        cls.configure_classification_config(model_lm.config, selected_token_ids, labels)
        return cls(model_lm.config, source_model=model_lm, classifier_weight=classifier_weight)

    @classmethod
    def configure_classification_config(
        cls,
        config: Gemma4Config,
        selected_token_ids: Sequence[int],
        labels: Sequence[str],
    ) -> None:
        """Write the label mapping and classifier metadata onto ``config`` in place.

        Sets ``num_labels``, ``id2label``, ``label2id``, ``classifier_token_ids``
        (label -> token id), ``architectures`` and ``problem_type``, and copies
        ``text_config.pad_token_id`` to the top level when it is unset there.

        Raises:
            ValueError: if the labels are empty or not unique, or if the number
                of token ids differs from the number of labels.
        """
        cls._validate_label_tokens(selected_token_ids, labels)
        config.num_labels = len(labels)
        config.id2label = dict(enumerate(labels))
        config.label2id = {label: i for i, label in enumerate(labels)}
        config.classifier_token_ids = {
            label: int(token_id) for label, token_id in zip(labels, selected_token_ids)
        }
        config.architectures = [cls.__name__]
        config.problem_type = "single_label_classification"
        # Lets pooling find padding from input_ids when no attention mask is given.
        if getattr(config, "pad_token_id", None) is None:
            config.pad_token_id = config.text_config.pad_token_id

    def get_input_embeddings(self):
        """Return the backbone's input embedding module."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the backbone's input embedding module."""
        self.model.set_input_embeddings(value)

    def get_per_layer_input_embeddings(self):
        """Return the backbone's per-layer input embedding module."""
        return self.model.get_per_layer_input_embeddings()

    def set_per_layer_input_embeddings(self, value):
        """Replace the backbone's per-layer input embedding module."""
        self.model.set_per_layer_input_embeddings(value)

    def _last_non_pad_token(
        self,
        logits: torch.Tensor,
        input_ids: torch.LongTensor | None,
        attention_mask: torch.Tensor | None,
        inputs_embeds: torch.FloatTensor | None,
    ) -> torch.Tensor | int:
        """Return, per batch row, the sequence index to pool.

        Resolution order:
        1. the last position with a non-zero entry in ``attention_mask``;
        2. the last position of ``input_ids`` that is not the pad token (the
           top-level ``pad_token_id``, else the text config's);
        3. ``-1`` (the final position) for a single sequence.

        ``logits`` is used only for its batch size, sequence length and device.

        Raises:
            ValueError: for a batch of more than one row when neither a mask nor
                a pad token is available, or when no inputs were given at all.
        """
        batch_size, seq_len = logits.shape[0], logits.shape[1]
        if attention_mask is not None:
            # With a KV cache the 2-D mask also covers past tokens; keep only the
            # trailing columns that line up with the positions in ``logits``.
            mask = attention_mask.to(logits.device)[:, -seq_len:]
            token_indices = torch.arange(seq_len, device=logits.device)
            return (mask * token_indices).argmax(-1)

        pad_token_id = getattr(self.config, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = getattr(self.config.get_text_config(), "pad_token_id", None)
        if input_ids is not None and pad_token_id is not None:
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
            non_pad = input_ids.to(logits.device).ne(pad_token_id)
            return (non_pad * token_indices).argmax(-1)

        if batch_size != 1:
            raise ValueError(
                "Cannot infer sequence lengths for a padded batch without a pad token."
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Expected input_ids or inputs_embeds.")
        return -1

    def _apply_final_logit_softcapping(self, logits: torch.Tensor) -> torch.Tensor:
        """Return ``cap * tanh(logits / cap)``, or ``logits`` unchanged when no cap is set.

        ``cap`` is the text config's ``final_logit_softcapping``.
        """
        cap = getattr(self.config.get_text_config(), "final_logit_softcapping", None)
        if cap is None:
            return logits
        return torch.tanh(logits / cap) * cap

    def _classification_loss(
        self, pooled_logits: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """Compute the loss for ``config.problem_type``, inferring it when unset.

        Inference (same rule as transformers' sequence classifiers):
        - ``num_labels == 1`` gives regression;
        - integer labels give single-label classification;
        - anything else gives multi-label classification.
        """
        labels = labels.to(pooled_logits.device)
        problem_type = self.config.problem_type
        if problem_type is None:
            if self.num_labels == 1:
                problem_type = "regression"
            elif labels.dtype in (torch.long, torch.int):
                problem_type = "single_label_classification"
            else:
                problem_type = "multi_label_classification"

        if problem_type == "regression":
            # MSE needs floating targets with the same dtype as the logits.
            targets = labels.to(pooled_logits.dtype)
            return nn.MSELoss()(pooled_logits.squeeze(), targets.squeeze())
        if problem_type == "multi_label_classification":
            # BCE rejects integer targets; cast 0/1 labels to the logits dtype.
            return nn.BCEWithLogitsLoss()(pooled_logits, labels.to(pooled_logits.dtype))
        return nn.CrossEntropyLoss()(
            pooled_logits.view(-1, self.num_labels),
            labels.view(-1),
        )

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        input_features: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        input_features_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        image_position_ids: torch.LongTensor | None = None,
        video_position_ids: torch.LongTensor | None = None,
        past_key_values=None,
        mm_token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Run the backbone, score every position, and pool one position per row.

        Multimodal, position and cache arguments (and ``**kwargs``) are forwarded
        unchanged to ``self.model``.

        Args:
            labels: Optional targets. When given, a loss is computed according to
                ``config.problem_type`` (inferred from ``num_labels`` and the
                label dtype when unset).
            return_dict: Return a ``SequenceClassifierOutputWithPast`` (default
                taken from ``config.use_return_dict``) or a plain tuple.

        Returns:
            ``SequenceClassifierOutputWithPast`` whose ``logits`` hold one row of
            ``num_labels`` scores per batch element. In tuple form the elements
            are ``(pooled_logits, *backbone_outputs[1:])``, with the loss
            prepended when ``labels`` were given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            input_features=input_features,
            attention_mask=attention_mask,
            input_features_mask=input_features_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            mm_token_type_ids=mm_token_type_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            image_position_ids=image_position_ids,
            video_position_ids=video_position_ids,
            return_dict=True,
            **kwargs,
        )
        logits = self._apply_final_logit_softcapping(self.score(outputs.last_hidden_state))
        last_positions = self._last_non_pad_token(
            logits,
            input_ids,
            attention_mask,
            inputs_embeds,
        )
        batch_rows = torch.arange(logits.shape[0], device=logits.device)
        pooled_logits = logits[batch_rows, last_positions]

        loss = self._classification_loss(pooled_logits, labels) if labels is not None else None

        if not return_dict:
            output = (pooled_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Register this custom class under AutoModelForSequenceClassification so that
# transformers' custom-code loading can resolve it (see
# PreTrainedModel.register_for_auto_class).
Gemma4ForSequenceClassification.register_for_auto_class("AutoModelForSequenceClassification")