Instructions to use baseten/gemma-4-e2b-it-sequence-classification with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use baseten/gemma-4-e2b-it-sequence-classification with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="baseten/gemma-4-e2b-it-sequence-classification", trust_remote_code=True)

# Load model directly
from transformers import AutoProcessor, AutoModelForSequenceClassification

processor = AutoProcessor.from_pretrained("baseten/gemma-4-e2b-it-sequence-classification", trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained("baseten/gemma-4-e2b-it-sequence-classification", trust_remote_code=True)
```
- Notebooks
- Google Colab
- Kaggle
| """Gemma 4 sequence classifier backed by selected next-token logits. | |
| This module is intentionally small: it reuses the Gemma 4 multimodal backbone and | |
| replaces the LM head with a classifier head containing selected token rows. | |
| """ | |
| from __future__ import annotations | |
| from collections.abc import Sequence | |
| import torch | |
| from torch import nn | |
| from transformers.modeling_outputs import SequenceClassifierOutputWithPast | |
| from transformers.models.gemma4.configuration_gemma4 import Gemma4Config | |
| from transformers.models.gemma4.modeling_gemma4 import Gemma4Model, Gemma4PreTrainedModel | |
class Gemma4ForSequenceClassification(Gemma4PreTrainedModel):
    """Pool the last text position and score it with selected Gemma 4 token rows.

    The model wraps a ``Gemma4Model`` backbone and a bias-free linear head
    (``score``). ``from_conditional_generation`` fills that head with rows taken
    from an existing LM head, so each class logit is the next-token logit of one
    chosen vocabulary token.
    """

    config_class = Gemma4Config
    base_model_prefix = "model"

    @classmethod
    def _can_set_experts_implementation(cls) -> bool:
        """Return ``True`` unconditionally.

        NOTE(review): presumably overrides a capability hook on the base class;
        confirm against ``Gemma4PreTrainedModel``.
        """
        return True

    def __init__(
        self,
        config: Gemma4Config,
        source_model: nn.Module | None = None,
        classifier_weight: torch.Tensor | None = None,
    ) -> None:
        """Build the classifier, optionally reusing an existing backbone and head.

        Args:
            config: Model configuration. ``config.num_labels`` sets the number of
                classes and ``config.text_config.hidden_size`` the input width of
                the scoring head.
            source_model: Optional module whose ``.model`` attribute is reused
                (not copied) as the backbone. When ``None`` a fresh
                ``Gemma4Model(config)`` is constructed.
            classifier_weight: Optional tensor copied into ``score.weight``. The
                head is first moved to this tensor's device and dtype.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = source_model.model if source_model is not None else Gemma4Model(config)
        self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
        if classifier_weight is not None:
            self.score.to(device=classifier_weight.device, dtype=classifier_weight.dtype)
            self.score.weight.data.copy_(classifier_weight)
        # post_init() only runs when neither a backbone nor head weights were
        # injected -- presumably so supplied weights are not re-initialised.
        if source_model is None and classifier_weight is None:
            self.post_init()

    @classmethod
    def from_conditional_generation(
        cls,
        model_lm: nn.Module,
        selected_token_ids: Sequence[int],
        labels: Sequence[str],
    ) -> "Gemma4ForSequenceClassification":
        """Create a classifier from a language model and a set of vocabulary tokens.

        The rows of ``model_lm.lm_head.weight`` at ``selected_token_ids`` are
        cloned to become the classifier weight. ``model_lm.config`` is updated in
        place with the classification metadata, and ``model_lm.model`` is reused
        as the backbone.

        Args:
            model_lm: Module exposing ``lm_head.weight``, ``config`` and ``model``.
            selected_token_ids: One vocabulary id per class, in class order.
            labels: Class names, aligned with ``selected_token_ids``.

        Returns:
            A new classifier instance sharing ``model_lm``'s backbone.

        Raises:
            ValueError: If ``labels`` and ``selected_token_ids`` differ in length.
        """
        token_ids = torch.tensor(selected_token_ids, device=model_lm.lm_head.weight.device)
        # detach().clone() so the head owns its storage, independent of lm_head.
        classifier_weight = model_lm.lm_head.weight.index_select(0, token_ids).detach().clone()
        cls.configure_classification_config(model_lm.config, selected_token_ids, labels)
        return cls(model_lm.config, source_model=model_lm, classifier_weight=classifier_weight)

    @classmethod
    def configure_classification_config(
        cls,
        config: Gemma4Config,
        selected_token_ids: Sequence[int],
        labels: Sequence[str],
    ) -> None:
        """Write classification metadata onto ``config`` in place.

        Sets ``num_labels``, ``id2label``, ``label2id``, ``classifier_token_ids``
        (label -> token id), ``architectures`` and ``problem_type``. When the
        config has no ``pad_token_id`` it is filled from
        ``config.text_config.pad_token_id``.

        Args:
            config: Configuration object to mutate.
            selected_token_ids: One vocabulary id per class, in class order.
            labels: Class names, aligned with ``selected_token_ids``.

        Raises:
            ValueError: If ``labels`` and ``selected_token_ids`` differ in length.
        """
        # Validate before touching the config: zip() would otherwise truncate
        # silently and leave num_labels out of step with the classifier rows.
        if len(labels) != len(selected_token_ids):
            raise ValueError(
                f"Expected one token id per label, got {len(labels)} labels and "
                f"{len(selected_token_ids)} token ids."
            )
        config.num_labels = len(labels)
        config.id2label = dict(enumerate(labels))
        config.label2id = {label: i for i, label in enumerate(labels)}
        config.classifier_token_ids = {
            label: int(token_id) for label, token_id in zip(labels, selected_token_ids)
        }
        config.architectures = [cls.__name__]
        config.problem_type = "single_label_classification"
        if getattr(config, "pad_token_id", None) is None:
            config.pad_token_id = config.text_config.pad_token_id

    def get_input_embeddings(self):
        """Return the backbone's input embeddings."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the backbone's input embeddings with ``value``."""
        self.model.set_input_embeddings(value)

    def get_per_layer_input_embeddings(self):
        """Return the backbone's per-layer input embeddings."""
        return self.model.get_per_layer_input_embeddings()

    def set_per_layer_input_embeddings(self, value):
        """Replace the backbone's per-layer input embeddings with ``value``."""
        self.model.set_per_layer_input_embeddings(value)

    def _last_non_pad_token(
        self,
        logits: torch.Tensor,
        input_ids: torch.LongTensor | None,
        attention_mask: torch.Tensor | None,
        inputs_embeds: torch.FloatTensor | None,
    ) -> torch.Tensor | int:
        """Locate the position to pool for each sequence in the batch.

        Resolution order:

        1. ``attention_mask``: the highest index whose mask value is non-zero
           (assumes a 0/1 mask; an all-zero row resolves to index 0).
        2. ``input_ids`` together with ``config.pad_token_id``: the highest index
           whose token differs from the pad token.
        3. Otherwise, for a batch of one, ``-1`` (the final position).

        Args:
            logits: Per-position logits; only its shape and device are used.
            input_ids: Token ids, if available.
            attention_mask: Attention mask, if available.
            inputs_embeds: Input embeddings, if available.

        Returns:
            A tensor with one index per batch row, or the integer ``-1``.

        Raises:
            ValueError: If the batch has more than one row and neither a mask nor
                a pad token can be used, or if both ``input_ids`` and
                ``inputs_embeds`` are ``None``.
        """
        batch_size = logits.shape[0]
        if attention_mask is not None:
            token_indices = torch.arange(logits.shape[1], device=logits.device)
            # Masked positions multiply to 0, so argmax picks the right-most
            # attended index regardless of which side the padding is on.
            return (attention_mask.to(logits.device) * token_indices).argmax(-1)

        pad_token_id = getattr(self.config, "pad_token_id", None)
        if input_ids is not None and pad_token_id is not None:
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
            non_pad = input_ids.to(logits.device).ne(pad_token_id)
            return (non_pad * token_indices).argmax(-1)

        if batch_size != 1:
            raise ValueError(
                "Cannot infer sequence lengths for a padded batch without a pad token."
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Expected input_ids or inputs_embeds.")
        return -1

    def _apply_final_logit_softcapping(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply ``cap * tanh(logits / cap)`` when the text config defines a cap.

        The cap is read from ``final_logit_softcapping`` on the text config;
        ``logits`` is returned unchanged when it is ``None``.
        """
        final_logit_softcapping = self.config.get_text_config().final_logit_softcapping
        if final_logit_softcapping is None:
            return logits
        logits = logits / final_logit_softcapping
        logits = torch.tanh(logits)
        return logits * final_logit_softcapping

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        input_features: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        input_features_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        image_position_ids: torch.LongTensor | None = None,
        video_position_ids: torch.LongTensor | None = None,
        past_key_values=None,
        mm_token_type_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Run the backbone, score every position, and pool one position per row.

        All arguments other than ``labels`` and ``return_dict`` (including
        ``**kwargs``) are forwarded unchanged to the backbone.

        Args:
            labels: Optional targets. When given, a loss is computed according to
                ``config.problem_type``: ``"regression"`` uses MSE on squeezed
                tensors, ``"multi_label_classification"`` uses BCE-with-logits,
                and anything else uses cross-entropy over ``num_labels`` classes.
            return_dict: Whether to return an output object rather than a tuple;
                defaults to ``config.use_return_dict``.

        Returns:
            A ``SequenceClassifierOutputWithPast`` holding the loss, pooled
            logits, and the backbone's ``past_key_values``, ``hidden_states`` and
            ``attentions``; or, when ``return_dict`` is false, a tuple of
            ``(loss?, pooled_logits, *outputs[1:])``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The backbone is always asked for an output object so that its fields
        # can be accessed by name below.
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            input_features=input_features,
            attention_mask=attention_mask,
            input_features_mask=input_features_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            mm_token_type_ids=mm_token_type_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            image_position_ids=image_position_ids,
            video_position_ids=video_position_ids,
            return_dict=True,
            **kwargs,
        )

        logits = self.score(outputs.last_hidden_state)
        logits = self._apply_final_logit_softcapping(logits)

        sequence_lengths = self._last_non_pad_token(
            logits,
            input_ids,
            attention_mask,
            inputs_embeds,
        )
        # Select one position per batch row.
        pooled_logits = logits[
            torch.arange(logits.shape[0], device=logits.device),
            sequence_lengths,
        ]

        loss = None
        if labels is not None:
            labels = labels.to(pooled_logits.device)
            if self.config.problem_type == "regression":
                loss = nn.MSELoss()(pooled_logits.squeeze(), labels.squeeze())
            elif self.config.problem_type == "multi_label_classification":
                loss = nn.BCEWithLogitsLoss()(pooled_logits, labels)
            else:
                loss = nn.CrossEntropyLoss()(
                    pooled_logits.view(-1, self.num_labels),
                    labels.view(-1),
                )

        if not return_dict:
            output = (pooled_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Register the class under the "AutoModelForSequenceClassification" auto class
# at import time.
Gemma4ForSequenceClassification.register_for_auto_class("AutoModelForSequenceClassification")