from typing import ClassVar, Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration

from ...cache_utils import Cache


class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
    main_input_name: ClassVar[str] = "doc_input_ids"

    def __init__(self, config):
        super().__init__(config=config)

        self.embedding_dim = self.config.embedding_dim
        # Projects the language model's last hidden states down to the
        # retrieval embedding dimension.
        self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)

        # Prefix the tied-weights keys so they resolve through this wrapper.
        if self.language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: int = 0,
    ):
| | r""" |
| | Returns: |
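
        Example (a minimal, hypothetical sketch -- the checkpoint name
        "your-org/new-task-model" and the processor inputs are placeholders,
        not a real released model):

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor

        >>> model = NewTaskModelForNewTask.from_pretrained("your-org/new-task-model")
        >>> processor = AutoProcessor.from_pretrained("your-org/new-task-model")

        >>> inputs = processor(text="What is shown in this document?", return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> embeddings = outputs[0]  # (batch_size, sequence_length, embedding_dim)
        ```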
| | """ |
        vlm_outputs = super().forward(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            token_type_ids=token_type_ids,
            cache_position=cache_position,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,  # the embeddings are built from the last hidden state
            return_dict=True,
            num_logits_to_keep=num_logits_to_keep,
        )
        last_hidden_states = vlm_outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, embedding_dim)

        # L2-normalize each token embedding.
        embeddings = proj / proj.norm(dim=-1, keepdim=True)

        # Zero out embeddings at padded positions so they cannot contribute
        # to downstream similarity scores.
        if attention_mask is not None:
            embeddings = embeddings * attention_mask.unsqueeze(-1)

        # `vlm_outputs` is a `ModelOutput`, which cannot be concatenated to a
        # tuple directly, so convert it to a tuple first.
        return (embeddings,) + vlm_outputs.to_tuple()

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing=True
    ) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)

        # Keep every copy of the vocabulary size in sync with the resized
        # embedding matrix.
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings

        return model_embeds
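

# Illustrative usage sketch, NOT part of the module: a hypothetical helper,
# assuming the per-token embeddings returned by `NewTaskModelForNewTask.forward`
# are compared with a ColBERT-style MaxSim late-interaction score, which is a
# common use for L2-normalized, padding-masked token embeddings like the ones
# produced above. Nothing here is required by, or shipped with, the model.
def maxsim_score(query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor) -> torch.Tensor:
    """Score every query against every document.

    Args:
        query_embeddings: (num_queries, query_len, dim) float tensor.
        doc_embeddings: (num_docs, doc_len, dim) float tensor.

    Returns:
        A (num_queries, num_docs) matrix of late-interaction scores.
    """
    # Pairwise token similarities: (num_queries, num_docs, query_len, doc_len).
    sim = torch.einsum("qnd,cmd->qcnm", query_embeddings, doc_embeddings)
    # For each query token, keep its best-matching document token, then sum
    # over query tokens. Padded document tokens were zeroed by the attention
    # mask in `forward`, so they score 0 against every query token.
    return sim.amax(dim=3).sum(dim=2)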