"""Gemma 4 sequence classifier backed by selected next-token logits. This module is intentionally small: it reuses the Gemma 4 multimodal backbone and replaces the LM head with a classifier head containing selected token rows. """ from __future__ import annotations from collections.abc import Sequence import torch from torch import nn from transformers.modeling_outputs import SequenceClassifierOutputWithPast from transformers.models.gemma4.configuration_gemma4 import Gemma4Config from transformers.models.gemma4.modeling_gemma4 import Gemma4Model, Gemma4PreTrainedModel class Gemma4ForSequenceClassification(Gemma4PreTrainedModel): """Pool the last text position and score it with selected Gemma 4 token rows.""" config_class = Gemma4Config base_model_prefix = "model" @classmethod def _can_set_experts_implementation(cls) -> bool: return True def __init__( self, config: Gemma4Config, source_model: nn.Module | None = None, classifier_weight: torch.Tensor | None = None, ) -> None: super().__init__(config) self.num_labels = config.num_labels self.model = source_model.model if source_model is not None else Gemma4Model(config) self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False) if classifier_weight is not None: self.score.to(device=classifier_weight.device, dtype=classifier_weight.dtype) self.score.weight.data.copy_(classifier_weight) if source_model is None and classifier_weight is None: self.post_init() @classmethod def from_conditional_generation( cls, model_lm: nn.Module, selected_token_ids: Sequence[int], labels: Sequence[str], ) -> "Gemma4ForSequenceClassification": token_ids = torch.tensor(selected_token_ids, device=model_lm.lm_head.weight.device) classifier_weight = model_lm.lm_head.weight.index_select(0, token_ids).detach().clone() cls.configure_classification_config(model_lm.config, selected_token_ids, labels) return cls(model_lm.config, source_model=model_lm, classifier_weight=classifier_weight) @classmethod def configure_classification_config( cls, config: Gemma4Config, selected_token_ids: Sequence[int], labels: Sequence[str], ) -> None: config.num_labels = len(labels) config.id2label = {i: label for i, label in enumerate(labels)} config.label2id = {label: i for i, label in enumerate(labels)} config.classifier_token_ids = { label: int(token_id) for label, token_id in zip(labels, selected_token_ids) } config.architectures = [cls.__name__] config.problem_type = "single_label_classification" if getattr(config, "pad_token_id", None) is None: config.pad_token_id = config.text_config.pad_token_id def get_input_embeddings(self): return self.model.get_input_embeddings() def set_input_embeddings(self, value): self.model.set_input_embeddings(value) def get_per_layer_input_embeddings(self): return self.model.get_per_layer_input_embeddings() def set_per_layer_input_embeddings(self, value): self.model.set_per_layer_input_embeddings(value) def _last_non_pad_token( self, logits: torch.Tensor, input_ids: torch.LongTensor | None, attention_mask: torch.Tensor | None, inputs_embeds: torch.FloatTensor | None, ) -> torch.Tensor | int: batch_size = logits.shape[0] if attention_mask is not None: token_indices = torch.arange(logits.shape[1], device=logits.device) return (attention_mask.to(logits.device) * token_indices).argmax(-1) pad_token_id = getattr(self.config, "pad_token_id", None) if input_ids is not None and pad_token_id is not None: token_indices = torch.arange(input_ids.shape[-1], device=logits.device) non_pad = input_ids.to(logits.device).ne(pad_token_id) return (non_pad * token_indices).argmax(-1) if batch_size != 1: raise ValueError( "Cannot infer sequence lengths for a padded batch without a pad token." ) if input_ids is None and inputs_embeds is None: raise ValueError("Expected input_ids or inputs_embeds.") return -1 def _apply_final_logit_softcapping(self, logits: torch.Tensor) -> torch.Tensor: final_logit_softcapping = self.config.get_text_config().final_logit_softcapping if final_logit_softcapping is None: return logits logits = logits / final_logit_softcapping logits = torch.tanh(logits) return logits * final_logit_softcapping def forward( self, input_ids: torch.LongTensor | None = None, pixel_values: torch.FloatTensor | None = None, pixel_values_videos: torch.FloatTensor | None = None, input_features: torch.FloatTensor | None = None, attention_mask: torch.Tensor | None = None, input_features_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, image_position_ids: torch.LongTensor | None = None, video_position_ids: torch.LongTensor | None = None, past_key_values=None, mm_token_type_ids: torch.LongTensor | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, return_dict: bool | None = None, **kwargs, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, pixel_values_videos=pixel_values_videos, input_features=input_features, attention_mask=attention_mask, input_features_mask=input_features_mask, position_ids=position_ids, past_key_values=past_key_values, mm_token_type_ids=mm_token_type_ids, inputs_embeds=inputs_embeds, use_cache=use_cache, image_position_ids=image_position_ids, video_position_ids=video_position_ids, return_dict=True, **kwargs, ) logits = self.score(outputs.last_hidden_state) logits = self._apply_final_logit_softcapping(logits) sequence_lengths = self._last_non_pad_token( logits, input_ids, attention_mask, inputs_embeds, ) pooled_logits = logits[ torch.arange(logits.shape[0], device=logits.device), sequence_lengths, ] loss = None if labels is not None: labels = labels.to(pooled_logits.device) if self.config.problem_type == "regression": loss = nn.MSELoss()(pooled_logits.squeeze(), labels.squeeze()) elif self.config.problem_type == "multi_label_classification": loss = nn.BCEWithLogitsLoss()(pooled_logits, labels) else: loss = nn.CrossEntropyLoss()( pooled_logits.view(-1, self.num_labels), labels.view(-1), ) if not return_dict: output = (pooled_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) Gemma4ForSequenceClassification.register_for_auto_class("AutoModelForSequenceClassification")