| |
|
| |
|
| | from transformers.modeling_utils import PreTrainedModel |
| | from transformers.utils import ( |
| | ModelOutput, |
| | add_start_docstrings, |
| | add_start_docstrings_to_model_forward, |
| | is_flash_attn_2_available, |
| | logging, |
| | replace_return_docstrings, |
| | is_flash_attn_greater_or_equal_2_10, |
| | ) |
| | from transformers.activations import ACT2FN |
| | from transformers.modeling_attn_mask_utils import ( |
| | _prepare_4d_attention_mask, |
| | _prepare_4d_attention_mask_for_sdpa, |
| | _prepare_4d_causal_attention_mask, |
| | _prepare_4d_causal_attention_mask_for_sdpa, |
| | ) |
| | from transformers.modeling_outputs import ( |
| | BaseModelOutput, |
| | BaseModelOutputWithPastAndCrossAttentions, |
| | Seq2SeqLMOutput, |
| | Seq2SeqModelOutput, |
| | ) |
| |
|
| | from transformers.cache_utils import Cache, HybridCache |
| | from transformers.modeling_outputs import ( |
| | BaseModelOutputWithPast, |
| | CausalLMOutputWithPast, |
| | SequenceClassifierOutputWithPast, |
| | TokenClassifierOutput, |
| | ) |
| |
|
| | from typing import List, Optional, Tuple, Union |
| |
|
| | from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2ForCausalLM,Gemma2DecoderLayer,Gemma2RMSNorm |
| | from .configuration_feynmodel import FeynModelConfig,Florence2VisionConfig |
| |
|
| | from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM |
| | import json |
| | import math |
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| | from torch.nn import CrossEntropyLoss |
| | import torch.utils.checkpoint as checkpoint |
| |
|
| | from collections import OrderedDict |
| | from einops import rearrange |
| | from timm.models.layers import DropPath, trunc_normal_ |
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
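| | # nn.Sequential variant that unpacks tuple outputs, so each sub-module can take and return an (x, size) pair. |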
| | class MySequential(nn.Sequential): |
| | def forward(self, *inputs): |
| | for module in self._modules.values(): |
| | if isinstance(inputs, tuple): |
| | inputs = module(*inputs) |
| | else: |
| | inputs = module(inputs) |
| | return inputs |
| |
|
| |
|
| | class PreNorm(nn.Module): |
| | def __init__(self, norm, fn, drop_path=None): |
| | super().__init__() |
| | self.norm = norm |
| | self.fn = fn |
| | self.drop_path = drop_path |
| |
|
| | def forward(self, x, *args, **kwargs): |
| | shortcut = x |
| | if self.norm is not None: |
| | x, size = self.fn(self.norm(x), *args, **kwargs) |
| | else: |
| | x, size = self.fn(x, *args, **kwargs) |
| |
|
| | if self.drop_path: |
| | x = self.drop_path(x) |
| |
|
| | x = shortcut + x |
| |
|
| | return x, size |
| |
|
| |
|
| | class Mlp(nn.Module): |
| | def __init__( |
| | self, |
| | in_features, |
| | hidden_features=None, |
| | out_features=None, |
| | act_layer=nn.GELU, |
| | ): |
| | super().__init__() |
| | out_features = out_features or in_features |
| | hidden_features = hidden_features or in_features |
| | self.net = nn.Sequential(OrderedDict([ |
| | ("fc1", nn.Linear(in_features, hidden_features)), |
| | ("act", act_layer()), |
| | ("fc2", nn.Linear(hidden_features, out_features)) |
| | ])) |
| |
|
| | def forward(self, x, size): |
| | return self.net(x), size |
| |
|
| |
|
| | class DepthWiseConv2d(nn.Module): |
| | def __init__( |
| | self, |
| | dim_in, |
| | kernel_size, |
| | padding, |
| | stride, |
| | bias=True, |
| | ): |
| | super().__init__() |
| | self.dw = nn.Conv2d( |
| | dim_in, dim_in, |
| | kernel_size=kernel_size, |
| | padding=padding, |
| | groups=dim_in, |
| | stride=stride, |
| | bias=bias |
| | ) |
| |
|
| | def forward(self, x, size): |
| | B, N, C = x.shape |
| | H, W = size |
| | assert N == H * W |
| |
|
| | x = self.dw(x.transpose(1, 2).view(B, C, H, W)) |
| | size = (x.size(-2), x.size(-1)) |
| | x = x.flatten(2).transpose(1, 2) |
| | return x, size |
| |
|
| |
|
| | class ConvEmbed(nn.Module): |
| | """ Image to Patch Embedding |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | patch_size=7, |
| | in_chans=3, |
| | embed_dim=64, |
| | stride=4, |
| | padding=2, |
| | norm_layer=None, |
| | pre_norm=True |
| | ): |
| | super().__init__() |
| | self.patch_size = patch_size |
| |
|
| | self.proj = nn.Conv2d( |
| | in_chans, embed_dim, |
| | kernel_size=patch_size, |
| | stride=stride, |
| | padding=padding |
| | ) |
| |
|
| | dim_norm = in_chans if pre_norm else embed_dim |
| | self.norm = norm_layer(dim_norm) if norm_layer else None |
| |
|
| | self.pre_norm = pre_norm |
| |
|
| | def forward(self, x, size): |
| | H, W = size |
| | if len(x.size()) == 3: |
| | if self.norm and self.pre_norm: |
| | x = self.norm(x) |
| | x = rearrange( |
| | x, 'b (h w) c -> b c h w', |
| | h=H, w=W |
| | ) |
| |
|
| | x = self.proj(x) |
| |
|
| | _, _, H, W = x.shape |
| | x = rearrange(x, 'b c h w -> b (h w) c') |
| | if self.norm and not self.pre_norm: |
| | x = self.norm(x) |
| |
|
| | return x, (H, W) |
| |
|
| |
|
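| | # Channel attention: q/k/v are split into groups and attention is computed over the channel dimension |
| | # (q^T k with softmax across channels) instead of over tokens. |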
| | class ChannelAttention(nn.Module): |
| |
|
| | def __init__(self, dim, groups=8, qkv_bias=True): |
| | super().__init__() |
| |
|
| | self.groups = groups |
| | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
| | self.proj = nn.Linear(dim, dim) |
| |
|
| | def forward(self, x, size): |
| | B, N, C = x.shape |
| |
|
| | qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4) |
| | q, k, v = qkv[0], qkv[1], qkv[2] |
| |
|
| | q = q * (float(N) ** -0.5) |
| | attention = q.transpose(-1, -2) @ k |
| | attention = attention.softmax(dim=-1) |
| | x = (attention @ v.transpose(-1, -2)).transpose(-1, -2) |
| | x = x.transpose(1, 2).reshape(B, N, C) |
| | x = self.proj(x) |
| | return x, size |
| |
|
| |
|
| | class ChannelBlock(nn.Module): |
| |
|
| | def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True, |
| | drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, |
| | conv_at_attn=True, conv_at_ffn=True): |
| | super().__init__() |
| |
|
| | drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() |
| |
|
| | self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None |
| | self.channel_attn = PreNorm( |
| | norm_layer(dim), |
| | ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), |
| | drop_path |
| | ) |
| | self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None |
| | self.ffn = PreNorm( |
| | norm_layer(dim), |
| | Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer), |
| | drop_path |
| | ) |
| |
|
| | def forward(self, x, size): |
| | if self.conv1: |
| | x, size = self.conv1(x, size) |
| | x, size = self.channel_attn(x, size) |
| |
|
| | if self.conv2: |
| | x, size = self.conv2(x, size) |
| | x, size = self.ffn(x, size) |
| |
|
| | return x, size |
| |
|
| |
|
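| | # Split a (B, H, W, C) feature map into non-overlapping windows of shape (num_windows*B, window_size, window_size, C). |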
| | def window_partition(x, window_size: int): |
| | B, H, W, C = x.shape |
| | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) |
| | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) |
| | return windows |
| |
|
| |
|
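| | # Inverse of window_partition: stitch the per-window tensors back into a (B, H, W, C) feature map. |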
| | def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): |
| | B = batch_size |
| | |
| | |
| | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) |
| | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) |
| | return x |
| |
|
| |
|
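| | # Non-shifted window attention (Swin-style): pad the map to a multiple of window_size, partition into windows, |
| | # apply self-attention inside each window, then merge the windows back and strip the padding. |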
| | class WindowAttention(nn.Module): |
| | def __init__(self, dim, num_heads, window_size, qkv_bias=True): |
| |
|
| | super().__init__() |
| | self.dim = dim |
| | self.window_size = window_size |
| | self.num_heads = num_heads |
| | head_dim = dim // num_heads |
| | self.scale = float(head_dim) ** -0.5 |
| |
|
| | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
| | self.proj = nn.Linear(dim, dim) |
| |
|
| | self.softmax = nn.Softmax(dim=-1) |
| |
|
| | def forward(self, x, size): |
| |
|
| | H, W = size |
| | B, L, C = x.shape |
| | assert L == H * W, "input feature has wrong size" |
| |
|
| | x = x.view(B, H, W, C) |
| |
|
| | pad_l = pad_t = 0 |
| | pad_r = (self.window_size - W % self.window_size) % self.window_size |
| | pad_b = (self.window_size - H % self.window_size) % self.window_size |
| | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) |
| | _, Hp, Wp, _ = x.shape |
| |
|
| | x = window_partition(x, self.window_size) |
| | x = x.view(-1, self.window_size * self.window_size, C) |
| |
|
| | |
| | |
| |
|
| | B_, N, C = x.shape |
| | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| | q, k, v = qkv[0], qkv[1], qkv[2] |
| |
|
| | q = q * self.scale |
| | attn = (q @ k.transpose(-2, -1)) |
| | attn = self.softmax(attn) |
| |
|
| | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) |
| | x = self.proj(x) |
| |
|
| | |
| | x = x.view( |
| | -1, self.window_size, self.window_size, C |
| | ) |
| | x = window_reverse(x, B, self.window_size, Hp, Wp) |
| |
|
| | if pad_r > 0 or pad_b > 0: |
| | x = x[:, :H, :W, :].contiguous() |
| |
|
| | x = x.view(B, H * W, C) |
| |
|
| | return x, size |
| |
|
| |
|
| | class SpatialBlock(nn.Module): |
| |
|
| | def __init__(self, dim, num_heads, window_size, |
| | mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU, |
| | norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True): |
| | super().__init__() |
| |
|
| | drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() |
| |
|
| | self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None |
| | self.window_attn = PreNorm( |
| | norm_layer(dim), |
| | WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias), |
| | drop_path |
| | ) |
| | self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None |
| | self.ffn = PreNorm( |
| | norm_layer(dim), |
| | Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer), |
| | drop_path |
| | ) |
| |
|
| | def forward(self, x, size): |
| | if self.conv1: |
| | x, size = self.conv1(x, size) |
| | x, size = self.window_attn(x, size) |
| |
|
| | if self.conv2: |
| | x, size = self.conv2(x, size) |
| | x, size = self.ffn(x, size) |
| | return x, size |
| |
|
| |
|
| | class DaViT(nn.Module): |
| | """ DaViT: Dual-Attention Transformer |
| | |
| | Args: |
| | in_chans (int): Number of input image channels. Default: 3. |
| | num_classes (int): Number of classes for classification head. Default: 1000. |
| | patch_size (tuple(int)): Patch size of convolution in different stages. Default: (7, 2, 2, 2). |
| | patch_stride (tuple(int)): Patch stride of convolution in different stages. Default: (4, 2, 2, 2). |
| | patch_padding (tuple(int)): Patch padding of convolution in different stages. Default: (3, 0, 0, 0). |
| | patch_prenorm (tuple(bool)): If True, perform norm before the convolution layer. Default: (False, False, False, False). |
| | embed_dims (tuple(int)): Patch embedding dimension in different stages. Default: (64, 128, 192, 256). |
| | num_heads (tuple(int)): Number of spatial attention heads in different stages. Default: (3, 6, 12, 24). |
| | num_groups (tuple(int)): Number of channel groups in different stages. Default: (3, 6, 12, 24). |
| | window_size (int): Window size. Default: 7. |
| | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. |
| | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. |
| | drop_path_rate (float): Stochastic depth rate. Default: 0.1. |
| | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. |
| | enable_checkpoint (bool): If True, enable checkpointing. Default: False. |
| | conv_at_attn (bool): If True, perform depthwise convolution before the attention layer. Default: True. |
| | conv_at_ffn (bool): If True, perform depthwise convolution before the ffn layer. Default: True. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | in_chans=3, |
| | num_classes=1000, |
| | depths=(1, 1, 3, 1), |
| | patch_size=(7, 2, 2, 2), |
| | patch_stride=(4, 2, 2, 2), |
| | patch_padding=(3, 0, 0, 0), |
| | patch_prenorm=(False, False, False, False), |
| | embed_dims=(64, 128, 192, 256), |
| | num_heads=(3, 6, 12, 24), |
| | num_groups=(3, 6, 12, 24), |
| | window_size=7, |
| | mlp_ratio=4., |
| | qkv_bias=True, |
| | drop_path_rate=0.1, |
| | norm_layer=nn.LayerNorm, |
| | enable_checkpoint=False, |
| | conv_at_attn=True, |
| | conv_at_ffn=True, |
| | ): |
| | super().__init__() |
| |
|
| | self.num_classes = num_classes |
| | self.embed_dims = embed_dims |
| | self.num_heads = num_heads |
| | self.num_groups = num_groups |
| | self.num_stages = len(self.embed_dims) |
| | self.enable_checkpoint = enable_checkpoint |
| | assert self.num_stages == len(self.num_heads) == len(self.num_groups) |
| |
|
| | num_stages = len(embed_dims) |
| | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)*2)] |
| |
|
| | depth_offset = 0 |
| | convs = [] |
| | blocks = [] |
| | for i in range(num_stages): |
| | conv_embed = ConvEmbed( |
| | patch_size=patch_size[i], |
| | stride=patch_stride[i], |
| | padding=patch_padding[i], |
| | in_chans=in_chans if i == 0 else self.embed_dims[i - 1], |
| | embed_dim=self.embed_dims[i], |
| | norm_layer=norm_layer, |
| | pre_norm=patch_prenorm[i] |
| | ) |
| | convs.append(conv_embed) |
| |
|
| | block = MySequential( |
| | *[ |
| | MySequential(OrderedDict([ |
| | ( |
| | 'spatial_block', SpatialBlock( |
| | embed_dims[i], |
| | num_heads[i], |
| | window_size, |
| | drop_path_rate=dpr[depth_offset+j*2], |
| | qkv_bias=qkv_bias, |
| | mlp_ratio=mlp_ratio, |
| | conv_at_attn=conv_at_attn, |
| | conv_at_ffn=conv_at_ffn, |
| | ) |
| | ), |
| | ( |
| | 'channel_block', ChannelBlock( |
| | embed_dims[i], |
| | num_groups[i], |
| | drop_path_rate=dpr[depth_offset+j*2+1], |
| | qkv_bias=qkv_bias, |
| | mlp_ratio=mlp_ratio, |
| | conv_at_attn=conv_at_attn, |
| | conv_at_ffn=conv_at_ffn, |
| | ) |
| | ) |
| | ])) for j in range(depths[i]) |
| | ] |
| | ) |
| | blocks.append(block) |
| | depth_offset += depths[i]*2 |
| |
|
| | self.convs = nn.ModuleList(convs) |
| | self.blocks = nn.ModuleList(blocks) |
| |
|
| | self.norms = norm_layer(self.embed_dims[-1]) |
| | self.avgpool = nn.AdaptiveAvgPool1d(1) |
| | self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() |
| |
|
| | self.apply(self._init_weights) |
| |
|
| | @property |
| | def dim_out(self): |
| | return self.embed_dims[-1] |
| |
|
| | def _init_weights(self, m): |
| | if isinstance(m, nn.Linear): |
| | trunc_normal_(m.weight, std=0.02) |
| | if m.bias is not None: |
| | nn.init.constant_(m.bias, 0) |
| | elif isinstance(m, nn.Conv2d): |
| | nn.init.normal_(m.weight, std=0.02) |
| | for name, _ in m.named_parameters(): |
| | if name in ['bias']: |
| | nn.init.constant_(m.bias, 0) |
| | elif isinstance(m, nn.LayerNorm): |
| | nn.init.constant_(m.weight, 1.0) |
| | nn.init.constant_(m.bias, 0) |
| | elif isinstance(m, nn.BatchNorm2d): |
| | nn.init.constant_(m.weight, 1.0) |
| | nn.init.constant_(m.bias, 0) |
| |
|
| | def forward_features_unpool(self, x): |
| | """ |
| | forward until avg pooling |
| | Args: |
| | x (_type_): input image tensor |
| | """ |
| | input_size = (x.size(2), x.size(3)) |
| | for conv, block in zip(self.convs, self.blocks): |
| | x, input_size = conv(x, input_size) |
| | if self.enable_checkpoint: |
| | x, input_size = checkpoint.checkpoint(block, x, input_size) |
| | else: |
| | x, input_size = block(x, input_size) |
| | return x |
| |
|
| | def forward_features(self, x): |
| | x = self.forward_features_unpool(x) |
| |
|
| | |
| | x = self.avgpool(x.transpose(1, 2)) |
| | |
| | x = torch.flatten(x, 1) |
| | x = self.norms(x) |
| |
|
| | return x |
| |
|
| | def forward(self, x): |
| | x = self.forward_features(x) |
| | x = self.head(x) |
| | return x |
| | |
| | @classmethod |
| | def from_config(cls, config): |
| | return cls( |
| | depths=config.depths, |
| | embed_dims=config.dim_embed, |
| | num_heads=config.num_heads, |
| | num_groups=config.num_groups, |
| | patch_size=config.patch_size, |
| | patch_stride=config.patch_stride, |
| | patch_padding=config.patch_padding, |
| | patch_prenorm=config.patch_prenorm, |
| | drop_path_rate=config.drop_path_rate, |
| | window_size=config.window_size, |
| | ) |
| |
|
| |
|
| |
|
| |
|
| | _CONFIG_FOR_DOC = "FeynModelConfig" |
| |
|
| | FEYNMODEL_START_DOCSTRING = r""" |
| | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
| | library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads |
| | etc.) |
| | |
| | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. |
| | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage |
| | and behavior. |
| | |
| | Parameters: |
| | config ([`FeynModelConfig`]): |
| | Model configuration class with all the parameters of the model. Initializing with a config file does not |
| | load the weights associated with the model, only the configuration. Check out the |
| | [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
| | """ |
| | FEYNMODEL_INPUTS_DOCSTRING = r""" |
| | Args: |
| | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide |
| | it. |
| | |
| | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
| | [`PreTrainedTokenizer.__call__`] for details. |
| | |
| | [What are input IDs?](../glossary#input-ids) |
| | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
| | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
| | |
| | - 1 for tokens that are **not masked**, |
| | - 0 for tokens that are **masked**. |
| | |
| | [What are attention masks?](../glossary#attention-mask) |
| | |
| | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
| | [`PreTrainedTokenizer.__call__`] for details. |
| | |
| | If `past_key_values` is used, optionally only the last `input_ids` have to be input (see |
| | `past_key_values`). |
| | |
| | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] |
| | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more |
| | information on the default strategy. |
| | |
| | - 1 indicates the head is **not masked**, |
| | - 0 indicates the head is **masked**. |
| | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, |
| | config.n_positions - 1]`. |
| | |
| | [What are position IDs?](../glossary#position-ids) |
| | past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): |
| | Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention |
| | blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` |
| | returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. |
| | |
| | Two formats are allowed: |
| | - a [`~cache_utils.Cache`] instance; |
| | - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of |
| | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy |
| | cache format. |
| | |
| | The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the |
| | legacy cache format will be returned. |
| | |
| | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't |
| | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` |
| | of shape `(batch_size, sequence_length)`. |
| | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
| | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This |
| | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the |
| | model's internal embedding lookup matrix. |
| | use_cache (`bool`, *optional*): |
| | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see |
| | `past_key_values`). |
| | output_attentions (`bool`, *optional*): |
| | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
| | tensors for more detail. |
| | output_hidden_states (`bool`, *optional*): |
| | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| | more detail. |
| | return_dict (`bool`, *optional*): |
| | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| | cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| | Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, |
| | this tensor is not affected by padding. It is used to update the cache in the correct position and to infer |
| | the complete sequence length. |
| | """ |
| |
|
| | |
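| | # Build a (batch, 1, query_len, kv_len) additive float mask: positions above the causal diagonal and beyond the |
| | # current cache_position are filled with min_dtype, and padding positions from the 2D attention_mask are masked too. |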
| | def _prepare_4d_causal_attention_mask_with_cache_position( |
| | attention_mask: torch.Tensor, |
| | sequence_length: int, |
| | target_length: int, |
| | dtype: torch.dtype, |
| | device: torch.device, |
| | min_dtype: float, |
| | cache_position: torch.Tensor, |
| | batch_size: int, |
| | ): |
| | |
| | |
| | if attention_mask is not None and attention_mask.dim() == 4: |
| | |
| | |
| | |
| | |
| | causal_mask = attention_mask[:, :, -sequence_length:, :] |
| | |
| | |
| | else: |
| | |
| | |
| | causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=torch.float32, device=device) |
| | |
| | if sequence_length != 1: |
| | causal_mask = torch.triu(causal_mask, diagonal=1) |
| | |
| | causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| | causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| | |
| | if attention_mask is not None: |
| | |
| | causal_mask = causal_mask.clone() |
| | mask_length = attention_mask.shape[-1] |
| | padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
| | padding_mask = padding_mask == 0 |
| | causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| | padding_mask, min_dtype |
| | ) |
| | |
| |
|
| | return causal_mask |
| |
|
| | class LearnedAbsolutePositionEmbedding2D(nn.Module): |
| | """ |
| | This module learns positional embeddings up to a fixed maximum size. |
| | """ |
| |
|
| | def __init__(self, embedding_dim=256, num_pos=50): |
| | super().__init__() |
| | self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) |
| | self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2)) |
| |
|
| | def forward(self, pixel_values): |
| | """ |
| | pixel_values: (batch_size, height, width, num_channels) |
| | returns: (batch_size, height, width, embedding_dim * 2) |
| | """ |
| | if len(pixel_values.shape) != 4: |
| | raise ValueError('pixel_values must be a 4D tensor') |
| | height, width = pixel_values.shape[1:3] |
| | width_values = torch.arange(width, device=pixel_values.device) |
| | height_values = torch.arange(height, device=pixel_values.device) |
| | x_emb = self.column_embeddings(width_values) |
| | y_emb = self.row_embeddings(height_values) |
| | |
| | pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) |
| | |
| | pos = pos.permute(2, 0, 1) |
| | pos = pos.unsqueeze(0) |
| | |
| | pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) |
| | |
| | pos = pos.permute(0, 2, 3, 1) |
| | return pos |
| |
|
| | class PositionalEmbeddingCosine1D(nn.Module): |
| | """ |
| | This class implements a very simple positional encoding. It follows closely |
| | the encoder from the link below: |
| | https://pytorch.org/tutorials/beginner/translation_transformer.html |
| | Args: |
| | embed_dim: The dimension of the embeddings. |
| | max_seq_len: The maximum length to precompute the positional encodings. |
| | """ |
| | def __init__( |
| | self, |
| | embed_dim: int = 512, |
| | max_seq_len: int = 1024) -> None: |
| | super(PositionalEmbeddingCosine1D, self).__init__() |
| | self.embed_dim = embed_dim |
| | self.max_seq_len = max_seq_len |
| | |
| | factor = math.log(10000) |
| | denominator = torch.exp( |
| | -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim) |
| | |
| | |
| | frequencies = \ |
| | torch.arange(0, self.max_seq_len) \ |
| | .reshape(self.max_seq_len, 1) * denominator |
| | pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim)) |
| | |
| | pos_idx_to_embed[:, 0::2] = torch.sin(frequencies) |
| | pos_idx_to_embed[:, 1::2] = torch.cos(frequencies) |
| | |
| | self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) |
| |
|
| | def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: |
| | """ |
| | Args: |
| | seq_embeds: The sequence embeddings in order. Allowed size: |
| | 1. [T, D], where T is the length of the sequence, and D is the |
| | frame embedding dimension. |
| | 2. [B, T, D], where B is the batch size and T and D are the |
| | same as above. |
| | Returns a tensor with the same dimensions as the input, i.e., |
| | [1, T, D] or [T, D]. |
| | """ |
| | shape_len = len(seq_embeds.shape) |
| | assert 2 <= shape_len <= 3 |
| | len_seq = seq_embeds.size(-2) |
| | assert len_seq <= self.max_seq_len |
| | pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :] |
| | |
| | if shape_len == 3: |
| | pos_embeds = pos_embeds.view( |
| | (1, pos_embeds.size(0), pos_embeds.size(1))) |
| | return pos_embeds |
| |
|
| |
|
| | class LearnedAbsolutePositionEmbedding1D(nn.Module): |
| | """ |
| | Learnable absolute positional embeddings for 1D sequences. |
| | Args: |
| | embed_dim: The dimension of the embeddings. |
| | max_seq_len: The maximum length to precompute the positional encodings. |
| | """ |
| | def __init__( |
| | self, |
| | embedding_dim: int = 512, |
| | num_pos: int = 1024) -> None: |
| | super(LearnedAbsolutePositionEmbedding1D, self).__init__() |
| | self.embeddings = nn.Embedding(num_pos, embedding_dim) |
| | self.num_pos = num_pos |
| |
|
| | def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: |
| | """ |
| | Args: |
| | seq_embeds: The sequence embeddings in order. Allowed size: |
| | 1. [T, D], where T is the length of the sequence, and D is the |
| | frame embedding dimension. |
| | 2. [B, T, D], where B is the batch size and T and D are the |
| | same as above. |
| | Returns a tensor with the same dimensions as the input, i.e., |
| | [1, T, D] or [T, D]. |
| | """ |
| | shape_len = len(seq_embeds.shape) |
| | assert 2 <= shape_len <= 3 |
| | len_seq = seq_embeds.size(-2) |
| | assert len_seq <= self.num_pos |
| | |
| | pos_embeds = self.embeddings(torch.arange(len_seq).to(seq_embeds.device)) |
| | |
| | if shape_len == 3: |
| | pos_embeds = pos_embeds.view( |
| | (1, pos_embeds.size(0), pos_embeds.size(1))) |
| | return pos_embeds |
| |
|
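| | # GIT-style multimodal mask: image (memory) tokens attend only to each other, text (tgt) tokens attend to all |
| | # image tokens plus causally to earlier text tokens; extra columns up to max_length are fully masked out. |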
| | def create_git_attention_mask( |
| | tgt: torch.Tensor, |
| | memory: torch.Tensor, |
| | max_length: int |
| | ) -> torch.Tensor: |
| | |
| | batch_size = tgt.size(0) |
| | num_tgt = tgt.shape[1] |
| | num_memory = memory.shape[1] |
| | total_length = num_memory + num_tgt |
| |
|
| | |
| | top_left = torch.zeros((num_memory, num_memory)) |
| | top_right = torch.full((num_memory, num_tgt), float(-3.4028e+38)) |
| |
|
| | |
| | bottom_left = torch.zeros((num_tgt, num_memory)) |
| |
|
| | |
| | bottom_right = torch.tril(torch.ones(num_tgt, num_tgt)) |
| |
|
| | |
| | bottom_right = bottom_right.masked_fill(bottom_right == 0, float(-3.4028e+38)) |
| | bottom_right = bottom_right.masked_fill(bottom_right == 1, float(0)) |
| |
|
| | |
| | left = torch.cat((top_left, bottom_left), dim=0) |
| | right = torch.cat((top_right, bottom_right), dim=0) |
| |
|
| | |
| | full_attention_mask = torch.cat((left, right), dim=1) |
| |
|
| | |
| | padding = torch.full((total_length, max_length - total_length), float(-3.4028e+38)) |
| | full_attention_mask = torch.cat((full_attention_mask, padding), dim=1) |
| |
|
| | |
| | full_attention_mask = full_attention_mask[None, None, :, :] |
| |
|
| | |
| | full_attention_mask = full_attention_mask.expand(batch_size, 1, full_attention_mask.size(-2), full_attention_mask.size(-1)) |
| |
|
| | return full_attention_mask |
| |
|
| | def get_position_ids_from_binary_attention_mask(mask): |
| | """ |
| | Extract position IDs from a binary attention mask. |
| | |
| | Args: |
| | mask (torch.Tensor): The attention mask tensor of shape (1, 1, seq_len, seq_len), |
| | where 1 indicates allowed attention and 0 indicates blocked attention. |
| | |
| | Returns: |
| | torch.Tensor: Position IDs of shape (1, seq_len), one sequential index per query position. |
| | """ |
| | |
| | _, _, seq_len, _ = mask.shape |
| | |
| | |
| | position_ids = torch.arange(seq_len, dtype=torch.long, device=mask.device) |
| | |
| | |
| | position_ids = position_ids.unsqueeze(0) |
| | |
| | return position_ids |
| |
|
| | def ensure_tensor(variable): |
| | |
| | if isinstance(variable, torch.Tensor): |
| | |
| | return variable |
| | else: |
| | |
| | try: |
| | |
| | tensor = torch.tensor(variable) |
| | |
| | return tensor |
| | except Exception as e: |
| | print(f"Error converting to tensor: {e}") |
| | raise |
| |
|
| | @add_start_docstrings( |
| | "The bare Model outputting raw hidden-states without any specific head on top.", |
| | FEYNMODEL_START_DOCSTRING, |
| | ) |
| | class FeynModel(Gemma2Model): |
| | """ |
| | Transformer decoder consisting of *config.num_hidden_layers* layers. |
| | Each layer is a [`Gemma2DecoderLayer`] plus a `LoraLayer` on the *proj* modules. |
| | NB: LoRA layers are added and activated on the proj modules only if `pixel_values` is not None. |
| | |
| | Args: |
| | config: FeynModelConfig |
| | """ |
| |
|
| | def __init__(self, config: FeynModelConfig): |
| | super().__init__(config) |
| | |
| | self.mode='llm' |
| | ''' |
| | self.image_patch_tokens = int( |
| | (config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1 |
| | ) |
| | |
| | if config.num_image_with_embedding is not None: |
| | self.image_patch_tokens *= config.num_image_with_embedding |
| | ''' |
| | self.image_patch_tokens = 577 |
| | self.post_init() |
| |
|
| | def get_input_embeddings(self): |
| | return self.embed_tokens |
| |
|
| | def set_input_embeddings(self, value): |
| | self.embed_tokens = value |
| |
|
| | |
| | |
| | |
| | @add_start_docstrings_to_model_forward(FEYNMODEL_INPUTS_DOCSTRING) |
| | def forward( |
| | self, |
| | input_ids: torch.LongTensor = None, |
| | attention_mask: Optional[torch.Tensor] = None, |
| | position_ids: Optional[torch.LongTensor] = None, |
| | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, |
| | inputs_embeds: Optional[torch.FloatTensor] = None, |
| | use_cache: Optional[bool] = None, |
| | output_attentions: Optional[bool] = None, |
| | output_hidden_states: Optional[bool] = None, |
| | return_dict: Optional[bool] = None, |
| | cache_position: Optional[torch.LongTensor] = None, |
| | causal_attention_mask: Optional[torch.Tensor] = None, |
| | **kwargs, |
| | ) -> Union[Tuple, BaseModelOutputWithPast]: |
| | |
| | |
| | |
| | |
| | |
| | if cache_position is None: |
| | batch_size = input_ids.size(0) if input_ids is not None else inputs_embeds.size(0) |
| | cache_position = torch.zeros((batch_size,), dtype=torch.long, device=input_ids.device if input_ids is not None else inputs_embeds.device) |
| |
|
| |
|
| | |
| | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| | output_hidden_states = ( |
| | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| | ) |
| | use_cache = use_cache if use_cache is not None else self.config.use_cache |
| | return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| |
|
| | if (input_ids is None) ^ (inputs_embeds is not None): |
| | raise ValueError( |
| | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" |
| | ) |
| |
|
| | if self.gradient_checkpointing and self.training and use_cache: |
| | logger.warning_once( |
| | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." |
| | ) |
| | use_cache = False |
| |
|
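| | # Text-only path: embed input_ids and build the standard causal mask. In the vision path, inputs_embeds already |
| | # contains the merged image features, so the precomputed GIT-style causal_attention_mask is reused instead. |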
| | if inputs_embeds is None: |
| | inputs_embeds = self.embed_tokens(input_ids) |
| | causal_mask = self._update_causal_mask( |
| | attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions |
| | ) |
| | else: |
| | causal_mask = ensure_tensor(causal_attention_mask) |
| | position_ids = get_position_ids_from_binary_attention_mask(attention_mask) |
| | |
| | |
| |
|
| | if cache_position is None: |
| | cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) |
| |
|
| | if position_ids is None : |
| | position_ids = cache_position.unsqueeze(0) |
| |
|
| | |
| | |
| | |
| | if not isinstance(position_ids, torch.Tensor): |
| | |
| | position_ids = torch.tensor(position_ids, dtype=torch.long, device=inputs_embeds.device) |
| | |
| | |
| | |
| | hidden_states = inputs_embeds |
| |
|
| | |
| | |
| | |
| | normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) |
| | hidden_states = hidden_states * normalizer |
| |
|
| | all_hidden_states = () if output_hidden_states else None |
| | all_self_attns = () if output_attentions else None |
| | |
| | for decoder_layer in self.layers: |
| | if output_hidden_states: |
| | all_hidden_states += (hidden_states,) |
| |
|
| | if self.gradient_checkpointing and self.training: |
| | layer_outputs = self._gradient_checkpointing_func( |
| | decoder_layer.__call__, |
| | hidden_states, |
| | causal_mask, |
| | position_ids, |
| | past_key_values, |
| | output_attentions, |
| | use_cache, |
| | cache_position, |
| | ) |
| | else: |
| | layer_outputs = decoder_layer( |
| | hidden_states, |
| | attention_mask=causal_mask, |
| | position_ids=position_ids, |
| | past_key_value=past_key_values, |
| | output_attentions=output_attentions, |
| | use_cache=use_cache, |
| | cache_position=cache_position, |
| | ) |
| |
|
| | hidden_states = layer_outputs[0] |
| |
|
| | if output_attentions: |
| | all_self_attns += (layer_outputs[1],) |
| | |
| | hidden_states = self.norm(hidden_states) |
| |
|
| | |
| | if output_hidden_states: |
| | all_hidden_states += (hidden_states,) |
| |
|
| | next_cache = past_key_values if use_cache else None |
| | |
| | if not return_dict: |
| | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) |
| | return BaseModelOutputWithPast( |
| | last_hidden_state=hidden_states, |
| | past_key_values=next_cache, |
| | hidden_states=all_hidden_states, |
| | attentions=all_self_attns, |
| | ) |
| | |
| |
|
| |
|
| | def _update_causal_mask( |
| | self, |
| | attention_mask: torch.Tensor, |
| | input_tensor: torch.Tensor, |
| | cache_position: torch.Tensor, |
| | past_key_values: Cache, |
| | output_attentions: bool, |
| | ): |
| |
|
| | |
| | |
| | |
| | |
| | |
| | if self.config._attn_implementation == "flash_attention_2": |
| | return attention_mask |
| |
|
| | dtype, device = input_tensor.dtype, input_tensor.device |
| | min_dtype = torch.finfo(dtype).min |
| | sequence_length = input_tensor.shape[1] |
| | if isinstance(past_key_values, HybridCache): |
| | target_length = past_key_values.get_max_length() |
| | else: |
| | target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] |
| |
|
| | |
| | causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( |
| | attention_mask, |
| | sequence_length=sequence_length, |
| | target_length=target_length, |
| | dtype=dtype, |
| | device=device, |
| | min_dtype=min_dtype, |
| | cache_position=cache_position, |
| | batch_size=input_tensor.shape[0], |
| | ) |
| | |
| | return causal_mask |
| |
|
| |
|
| |
|
| | class FeynModelForCausalLM(Gemma2ForCausalLM): |
| | _tied_weights_keys = ["lm_head.weight"] |
| | config_class = FeynModelConfig |
| | def __init__(self, config): |
| | super().__init__(config) |
| | config.vision_config=Florence2VisionConfig.from_dict(config.vision_config) |
| | self.model = FeynModel(config) |
| | |
| | |
| | self.vision_tower = DaViT.from_config(config=config.vision_config) |
| | self._build_image_projection_layers(config) |
| |
|
| | self.__causal_attention_mask = None |
| | |
| | |
| | self.post_init() |
| |
|
| | |
| | def _build_image_projection_layers(self, config): |
| | image_dim_out = config.vision_config.dim_embed[-1] |
| | dim_projection = config.vision_config.projection_dim |
| | self.image_projection = nn.Parameter( |
| | torch.empty(image_dim_out, dim_projection) |
| | ) |
| | self.image_proj_norm = nn.LayerNorm(dim_projection) |
| | image_pos_embed_config = config.vision_config.image_pos_embed |
| | if image_pos_embed_config['type'] == 'learned_abs_2d': |
| | self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( |
| | embedding_dim=image_dim_out, |
| | num_pos=image_pos_embed_config['max_pos_embeddings'] |
| | ) |
| | else: |
| | raise NotImplementedError('Not implemented yet') |
| |
|
| | self.image_feature_source = config.vision_config.image_feature_source |
| |
|
| | |
| | visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding |
| | if visual_temporal_embedding_config['type'] == 'COSINE': |
| | self.visual_temporal_embed = PositionalEmbeddingCosine1D( |
| | embed_dim=image_dim_out, |
| | max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings'] |
| | ) |
| | else: |
| | raise NotImplementedError('Not implemented yet') |
| |
|
| | |
| |
|
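| | # Prepend the projected image tokens to the text embeddings and extend the attention mask accordingly. |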
| | def _merge_input_ids_with_image_features(self, image_features, inputs_embeds): |
| | batch_size, image_token_length = image_features.size()[:-1] |
| | device = image_features.device |
| | image_attention_mask = torch.ones(batch_size, image_token_length, device=device) |
| |
|
| | if inputs_embeds is None: |
| | return image_features, image_attention_mask |
| |
|
| | task_prefix_embeds = inputs_embeds |
| | task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device) |
| |
|
| | |
| | if len(task_prefix_attention_mask.shape) == 3: |
| | task_prefix_attention_mask = task_prefix_attention_mask.squeeze(1) |
| |
|
| | |
| | if image_features.size(0) != task_prefix_embeds.size(0): |
| | raise ValueError("Batch sizes of image_features and task_prefix_embeds do not match") |
| |
|
| | |
| | if image_features.dim() < task_prefix_embeds.dim(): |
| | image_features = image_features.unsqueeze(-1) |
| | elif task_prefix_embeds.dim() < image_features.dim(): |
| | task_prefix_embeds = task_prefix_embeds.unsqueeze(-1) |
| |
|
| | |
| | if image_features.size(2) != task_prefix_embeds.size(2): |
| | |
| | raise ValueError("Internal dimensions of image_features and task_prefix_embeds do not match") |
| |
|
| | inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1) |
| | attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1) |
| |
|
| | return inputs_embeds, attention_mask |
| |
|
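| | # Run DaViT on the pixels, add the learned 2D positional and cosine temporal embeddings, gather the configured |
| | # image_feature_source pools, then apply image_projection and layer-norm to get the image token embeddings. |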
| | def _encode_image(self, pixel_values): |
| | if len(pixel_values.shape) == 4: |
| | batch_size, C, H, W = pixel_values.shape |
| | T = 1 |
| | x = self.vision_tower.forward_features_unpool(pixel_values) |
| | else: |
| | |
| | pixel_values = pixel_values.unsqueeze(0) |
| | batch_size, C, H, W = pixel_values.shape |
| | T = 1 |
| | x = self.vision_tower.forward_features_unpool(pixel_values) |
| | |
| | if self.image_pos_embed is not None: |
| | x = x.view(batch_size * T, -1, x.shape[-1]) |
| | num_tokens = x.shape[-2] |
| | h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5) |
| | assert h * w == num_tokens, 'only support square feature maps for now' |
| | x = x.view(batch_size * T, h, w, x.shape[-1]) |
| | pos_embed = self.image_pos_embed(x) |
| | x = x + pos_embed |
| | x = x.view(batch_size, T * h*w, x.shape[-1]) |
| |
|
| | if self.visual_temporal_embed is not None: |
| | visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) |
| | x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1]) |
| |
|
| | x_feat_dict = {} |
| |
|
| | spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) |
| | x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x |
| |
|
| | temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) |
| | x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x |
| |
|
| | x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] |
| | x_feat_dict['last_frame'] = x |
| |
|
| | new_x = [] |
| | for _image_feature_source in self.image_feature_source: |
| | if _image_feature_source not in x_feat_dict: |
| | raise ValueError('invalid image feature source: {}'.format(_image_feature_source)) |
| | new_x.append(x_feat_dict[_image_feature_source]) |
| |
|
| | x = torch.cat(new_x, dim=1) |
| |
|
| | x = x @ self.image_projection |
| | x = self.image_proj_norm(x) |
| |
|
| | return x |
| | |
| |
|
| | def get_input_embeddings(self): |
| | return self.model.embed_tokens |
| |
|
| | def set_input_embeddings(self, value): |
| | self.model.embed_tokens = value |
| |
|
| | def get_output_embeddings(self): |
| | return self.lm_head |
| |
|
| | def set_output_embeddings(self, new_embeddings): |
| | self.lm_head = new_embeddings |
| |
|
| | def set_decoder(self, decoder): |
| | self.model = decoder |
| |
|
| | def get_decoder(self): |
| | return self.model |
| |
|
| | @add_start_docstrings_to_model_forward(FEYNMODEL_INPUTS_DOCSTRING) |
| | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) |
| | def forward( |
| | self, |
| | input_ids: torch.LongTensor = None, |
| | pixel_values: Optional[torch.Tensor] = None, |
| | attention_mask: Optional[torch.Tensor] = None, |
| | position_ids: Optional[torch.LongTensor] = None, |
| | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, |
| | inputs_embeds: Optional[torch.FloatTensor] = None, |
| | labels: Optional[torch.LongTensor] = None, |
| | use_cache: Optional[bool] = None, |
| | output_attentions: Optional[bool] = None, |
| | output_hidden_states: Optional[bool] = None, |
| | return_dict: Optional[bool] = None, |
| | cache_position: Optional[torch.LongTensor] = None, |
| | **kwargs, |
| | ) -> Union[Tuple, CausalLMOutputWithPast]: |
| | r""" |
| | Args: |
| | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
| | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
| | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
| | |
| | Returns: |
| | |
| | Example: |
| | |
| | ```python |
| | >>> from transformers import AutoTokenizer, Gemma2ForCausalLM |
| | |
| | >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b") |
| | >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") |
| | |
| | >>> prompt = "What is your favorite condiment?" |
| | >>> inputs = tokenizer(prompt, return_tensors="pt") |
| | |
| | >>> # Generate |
| | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) |
| | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
| | "What is your favorite condiment?" |
| | ```""" |
| |
|
| | |
| | if self.training and self.config._attn_implementation != "eager": |
| | logger.warning_once( |
| | "It is strongly recommended to train FeynModel models with the `eager` attention implementation " |
| | f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`." |
| | ) |
| | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| | output_hidden_states = ( |
| | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| | ) |
| | return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| |
|
| | if pixel_values is not None: |
| | self.model.mode='vlm' |
| | |
| | if input_ids is not None: |
| | inputs_embeds = self.get_input_embeddings()(input_ids) |
| | image_features = self._encode_image(pixel_values) |
| | inputs_embeds, causal_attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds ) |
| | causal_attention_mask = create_git_attention_mask(tgt=input_ids, memory=image_features,max_length=8192) |
| | causal_attention_mask=causal_attention_mask.to(input_ids.device) |
| | self.__causal_attention_mask=causal_attention_mask |
| |
|
| | |
| | if pixel_values is not None: |
| | outputs = self.model( |
| | input_ids=None, |
| | attention_mask=causal_attention_mask, |
| | position_ids=position_ids, |
| | past_key_values=past_key_values, |
| | inputs_embeds=inputs_embeds, |
| | use_cache=use_cache, |
| | output_attentions=output_attentions, |
| | output_hidden_states=output_hidden_states, |
| | return_dict=return_dict, |
| | cache_position=cache_position, |
| | causal_attention_mask=causal_attention_mask, |
| | ) |
| | else: |
| | outputs = self.model( |
| | input_ids=input_ids, |
| | attention_mask=attention_mask, |
| | position_ids=position_ids, |
| | past_key_values=past_key_values, |
| | inputs_embeds=inputs_embeds, |
| | use_cache=use_cache, |
| | output_attentions=output_attentions, |
| | output_hidden_states=output_hidden_states, |
| | return_dict=return_dict, |
| | cache_position=cache_position, |
| | causal_attention_mask=self.__causal_attention_mask, |
| | ) |
| | |
| | |
| | hidden_states = outputs[0] |
| | logits = self.lm_head(hidden_states) |
| | |
| | if self.config.final_logit_softcapping is not None: |
| | logits = logits / self.config.final_logit_softcapping |
| | logits = torch.tanh(logits) |
| | logits = logits * self.config.final_logit_softcapping |
| | |
| |
|
| | logits = logits.float() |
| | loss = None |
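| | # Drop the image-token prefix and the final position from the logits, shift the labels by one, and compute |
| | # cross-entropy over the text positions only. |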
| | if labels is not None: |
| | |
| | num_image_tokens = self.model.image_patch_tokens |
| | shifted_logits = logits[:, num_image_tokens:-1, :].contiguous() |
| | labels = labels[:, 1:].contiguous() |
| | loss_fct = CrossEntropyLoss() |
| | loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1)) |
| | |
| | if not return_dict: |
| | |
| | output = (logits,) + outputs[1:] |
| | return (loss,) + output if loss is not None else output |
| | |
| | return CausalLMOutputWithPast( |
| | loss=loss, |
| | logits=logits, |
| | past_key_values=outputs.past_key_values, |
| | hidden_states=outputs.hidden_states, |
| | attentions=outputs.attentions, |
| | ) |
| |
|
| | def prepare_inputs_for_generation( |
| | self, |
| | input_ids, |
| | past_key_values=None, |
| | attention_mask=None, |
| | inputs_embeds=None, |
| | cache_position=None, |
| | position_ids=None, |
| | use_cache=True, |
| | **kwargs, |
| | ): |
| | |
| | |
| | |
| | |
| | |
| | if past_key_values is not None: |
| | if inputs_embeds is not None: |
| | input_ids = input_ids[:, -cache_position.shape[0] :] |
| | elif input_ids.shape[1] != cache_position.shape[0]: |
| | input_ids = input_ids[:, cache_position] |
| |
|
| | if attention_mask is not None and position_ids is None: |
| | |
| | position_ids = attention_mask.long().cumsum(-1) - 1 |
| | position_ids.masked_fill_(attention_mask == 0, 1) |
| | if past_key_values: |
| | |
| | position_ids = position_ids[:, -input_ids.shape[1] :] |
| | |
| | |
| | |
| | |
| | |
| | position_ids = position_ids.clone(memory_format=torch.contiguous_format) |
| | |
| |
|
| | |
| | if inputs_embeds is not None and cache_position[0] == 0: |
| | |
| | model_inputs = {"inputs_embeds": inputs_embeds} |
| | else: |
| | |
| | |
| | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)} |
| |
|
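| | # With a HybridCache and a 2D mask, rebuild the full 4D causal mask so its key/value length matches the cache's |
| | # max length; dtype/min_dtype are derived from lm_head (with a special case for quantized Linear weights). |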
| | if isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2: |
| | if inputs_embeds is not None and input_ids.size(1)!= 0 : |
| | |
| | batch_size, sequence_length, _ = inputs_embeds.shape |
| | device = inputs_embeds.device |
| | |
| | else: |
| | batch_size, sequence_length = position_ids.shape |
| | device = input_ids.device |
| | |
| |
|
| | |
| | |
| | if hasattr(self.lm_head, 'weight'): |
| | |
| | if isinstance(self.lm_head.weight, torch.Tensor): |
| | dtype = self.lm_head.weight.dtype |
| | elif callable(self.lm_head.weight): |
| | dtype = self.lm_head.weight().dtype |
| | else: |
| | raise TypeError(f"Unexpected type for self.lm_head.weight: {type(self.lm_head.weight)}") |
| | |
| | |
| | |
| | if isinstance(self.lm_head, torch.ao.nn.quantized.dynamic.Linear): |
| | |
| | weight, bias = self.lm_head._weight_bias() |
| | dtype = weight.dtype |
| | else: |
| | dtype = self.lm_head.weight.dtype |
| |
|
| | |
| | if torch.is_floating_point(torch.empty(0, dtype=dtype)): |
| | |
| | min_dtype = torch.finfo(torch.float32).min |
| | else: |
| | min_dtype = torch.iinfo(dtype).min |
| | |
| |
|
| |
|
| | attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( |
| | attention_mask, |
| | sequence_length=sequence_length, |
| | target_length=past_key_values.get_max_length(), |
| | dtype=dtype, |
| | device=device, |
| | min_dtype=min_dtype, |
| | cache_position=cache_position, |
| | batch_size=batch_size, |
| | ) |
| | |
| |
|
| | model_inputs.update( |
| | { |
| | "position_ids": position_ids, |
| | "cache_position": cache_position, |
| | "past_key_values": past_key_values, |
| | "use_cache": use_cache, |
| | "attention_mask": attention_mask, |
| | } |
| | ) |
| | return model_inputs |
| | |
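| | # generate() wrapper: when pixel_values are given, encode the image, merge it into inputs_embeds, build the |
| | # GIT-style mask and call the parent generate with inputs_embeds; otherwise fall back to plain text generation. |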
| | def generate( |
| | self, |
| | input_ids, |
| | pixel_values=None, |
| | max_length=None, |
| | do_sample=True, |
| | temperature=0.7, |
| | **kwargs |
| | ): |
| | |
| |
|
| | if pixel_values is not None: |
| | if input_ids is not None: |
| | |
| | inputs_embeds = self.get_input_embeddings()(input_ids) |
| | |
| | image_features = self._encode_image(pixel_values) |
| | inputs_embeds, causal_attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds ) |
| | causal_attention_mask = create_git_attention_mask(tgt=input_ids, memory=image_features,max_length=max_length) |
| | causal_attention_mask=causal_attention_mask.to(input_ids.device) |
| | self.__causal_attention_mask=causal_attention_mask |
| | self.model.mode='vlm' |
| | result = super().generate( |
| | input_ids=None, |
| | inputs_embeds=inputs_embeds, |
| | max_length=max_length, |
| | do_sample=do_sample, |
| | temperature=temperature, |
| | **kwargs |
| | ) |
| | |
| | else: |
| | |
| | self.model.mode = 'llm' |
| | result = super().generate( |
| | input_ids=input_ids, |
| | |
| | max_length=max_length, |
| | do_sample=do_sample, |
| | temperature=temperature, |
| | **kwargs |
| | ) |
| | self.__causal_attention_mask = None |
| |
|
| | return result |
| | |
| |
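| | # Minimal usage sketch (illustrative only): the checkpoint path below is a placeholder, and the exact |
| | # processor/tokenizer pairing depends on how the model repository is packaged. |
| | # |
| | # from transformers import AutoTokenizer |
| | # |
| | # tokenizer = AutoTokenizer.from_pretrained("path/to/feynmodel-checkpoint")  # placeholder path |
| | # model = FeynModelForCausalLM.from_pretrained("path/to/feynmodel-checkpoint")  # placeholder path |
| | # inputs = tokenizer("Describe this image:", return_tensors="pt") |
| | # pixel_values = ...  # (1, 3, H, W) tensor produced by the image processor |
| | # output_ids = model.generate(inputs.input_ids, pixel_values=pixel_values, max_length=256) |
| | # print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]) |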
|