# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
from collections import OrderedDict
import math
from typing import Callable, Optional, Type, Union

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint


class LayerNormFp32(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        # Run the normalization in float32 for numerical stability, then cast back.
        x = F.layer_norm(
            x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps
        )
        return x.to(orig_type)


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm (with cast back to input dtype)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.to(orig_type)


class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation used by the original CLIP."""

    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)


class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch (CaiT-style).

    Args:
        dim: number of channels to scale.
        init_values: initial value for every scale entry.
        inplace: multiply in place (saves memory; mutates the input tensor).
    """

    def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Attention(nn.Module):
    """Multi-head self-attention with several optional variants.

    Supports scaled-cosine attention (normalized q/k with a learned, clamped
    per-head logit scale), QK normalization, per-head output scaling, and an
    inner norm on the concatenated head outputs before the final projection.

    Args:
        dim: embedding dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: add a bias to the fused QKV projection.
        qk_norm: apply ``norm_layer`` to per-head queries and keys.
        scaled_cosine: use scaled-cosine attention. Mutually exclusive with
            ``qk_norm``.
        scale_heads: learn a per-head scale applied to the attention output.
        inner_norm: normalize the merged head outputs before ``out_proj``.
        logit_scale_max: clamp ceiling (in log space) for the cosine logit scale.
        norm_layer: normalization class used for the optional norms.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        scaled_cosine: bool = False,
        scale_heads: bool = False,
        inner_norm: bool = False,
        logit_scale_max: float = math.log(1.0 / 0.01),
        norm_layer: Type[nn.Module] = LayerNormFp32,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        assert not (scaled_cosine and qk_norm), (
            "Cannot activate both scaled cosine and QK normalization"
        )
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.logit_scale_max = logit_scale_max
        # Prefer fused scaled_dot_product_attention when this torch build has it.
        self.use_fsdpa = hasattr(nn.functional, "scaled_dot_product_attention")

        # Fused QKV projection weight, scaled at init like nn.MultiheadAttention.
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        if qk_norm:
            self.ln_q = norm_layer(self.head_dim)
            self.ln_k = norm_layer(self.head_dim)
        else:
            self.ln_q = nn.Identity()
            self.ln_k = nn.Identity()

        if self.scaled_cosine:
            # Learned per-head temperature (stored in log space), init log(10).
            self.logit_scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1)))
            )
        else:
            self.logit_scale = None
        self.attn_drop = nn.Dropout(attn_drop)
        if self.scale_heads:
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None
        if inner_norm:
            self.ln_inner = norm_layer(dim)
        else:
            self.ln_inner = nn.Identity()
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Apply self-attention.

        Args:
            x: input of shape (batch, seq_len, dim).
            attn_mask: optional additive float mask or boolean mask
                (True == masked), shape broadcastable to
                (batch, num_heads, seq_len, seq_len); a 3-D mask is assumed to
                be (batch * num_heads, seq_len, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, dim).
        """
        N, L, C = x.shape
        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        # (N, L, C) -> (N, num_heads, L, head_dim)
        q = q.reshape(N, L, self.num_heads, -1).transpose(1, 2)
        k = k.reshape(N, L, self.num_heads, -1).transpose(1, 2)
        v = v.reshape(N, L, self.num_heads, -1).transpose(1, 2)

        if attn_mask is not None:
            if attn_mask.ndim == 3:
                # (N * num_heads, L, L) -> (N, num_heads, L, L)
                attn_mask = attn_mask.reshape(N, self.num_heads, L, L)
            if attn_mask.dtype == torch.bool:
                # Convert boolean mask (True == masked) into an additive mask.
                new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                attn_mask = new_attn_mask
            else:
                attn_mask = attn_mask.to(dtype=q.dtype)

        if self.logit_scale is not None:
            # Scaled-cosine attention: cosine similarity of q/k multiplied by a
            # learned, clamped per-head temperature.
            # BUGFIX: q/k/v are 4-D (N, heads, L, head_dim) here; torch.bmm only
            # accepts 3-D inputs, so use torch.matmul, which batches over the
            # leading dimensions.
            attn = torch.matmul(
                F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)
            )
            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
            attn = attn * logit_scale
            if attn_mask is not None:
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = torch.matmul(attn, v)
        else:
            q = self.ln_q(q)
            k = self.ln_k(k)
            if self.use_fsdpa:
                x = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=attn_mask,
                    dropout_p=self.attn_drop.p if self.training else 0.0,
                )
            else:
                # Manual attention fallback; inputs are 4-D, so matmul not bmm
                # (same fix as the scaled-cosine branch above).
                q = q * self.scale
                attn = torch.matmul(q, k.transpose(-1, -2))
                if attn_mask is not None:
                    attn = attn + attn_mask
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                x = torch.matmul(attn, v)

        if self.head_scale is not None:
            x = x * self.head_scale
        # (N, num_heads, L, head_dim) -> (N, L, C)
        x = x.transpose(1, 2).reshape(N, L, C)
        x = self.ln_inner(x)
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x


class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block built on ``nn.MultiheadAttention``.

    Optionally acts as a cross-attention block (separate key/value norm) and
    supports LayerScale on both residual branches.

    Args:
        d_model: embedding dimension.
        n_head: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``d_model``.
        ls_init_value: LayerScale init value; ``None`` disables LayerScale.
        act_layer: activation class for the MLP.
        norm_layer: normalization class.
        is_cross_attention: add a dedicated norm (``ln_1_kv``) for keys/values.
        batch_first: layout passed through to ``nn.MultiheadAttention``.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        is_cross_attention: bool = False,
        batch_first: bool = True,
    ):
        super().__init__()
        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)
        self.ls_1 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )
        if is_cross_attention:
            self.ln_1_kv = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                ]
            )
        )
        self.ls_2 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )

    def get_weight_dtype(self) -> torch.dtype:
        """Return the dtype of the block's weights (pre-quantization if int8)."""
        if hasattr(self.mlp.c_fc, "int8_original_dtype"):
            return self.mlp.c_fc.int8_original_dtype
        return self.mlp.c_fc.weight.dtype

    def attention(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ):
        """Run multi-head attention; keys/values default to the queries."""
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x
        attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
        return self.attn(
            q_x,
            k_x,
            v_x,
            need_weights=False,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
        )[0]

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ):
        """Pre-norm attention + MLP with residual connections."""
        # Key/value norm only exists (and applies) for cross-attention blocks.
        k_x = (
            self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
        )
        v_x = (
            self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
        )
        x = q_x + self.ls_1(
            self.attention(
                q_x=self.ln_1(q_x),
                k_x=k_x,
                v_x=v_x,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask,
            )
        )
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x


class CustomResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block built on the custom :class:`Attention`.

    Exposes the extra attention variants (QK norm, scaled-cosine, head scaling,
    inner norm) plus optional post-attention and mid-MLP norms.

    Args:
        d_model: embedding dimension.
        n_head: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``d_model``.
        ls_init_value: LayerScale init value; ``None`` disables LayerScale.
        act_layer: activation class for the MLP.
        norm_layer: normalization class.
        qk_norm: normalize per-head queries/keys inside attention.
        scale_cosine_attn: use scaled-cosine attention.
        scale_heads: learn per-head output scales.
        scale_attn_inner: normalize merged heads before the attention out-proj.
        scale_attn: extra norm applied to the attention output.
        scale_fc: extra norm between the MLP's fc and projection.
        batch_first: must be True; the custom Attention is batch-first only.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = LayerNorm,
        qk_norm: bool = False,
        scale_cosine_attn: bool = False,
        scale_heads: bool = False,
        scale_attn_inner: bool = False,
        scale_attn: bool = False,
        scale_fc: bool = False,
        batch_first: bool = True,
    ):
        super().__init__()
        assert batch_first, "batch_first must be True for CustomResidualAttentionBlock"
        self.ln_1 = norm_layer(d_model)
        self.attn = Attention(
            d_model,
            n_head,
            qk_norm=qk_norm,
            scaled_cosine=scale_cosine_attn,
            scale_heads=scale_heads,
            inner_norm=scale_attn_inner,
            norm_layer=norm_layer,
        )
        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
        self.ls_1 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )
        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    ("gelu", act_layer()),
                    ("ln", norm_layer(mlp_width) if scale_fc else nn.Identity()),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                ]
            )
        )
        self.ls_2 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )

    def get_weight_dtype(self) -> torch.dtype:
        """Return the dtype of the block's weights (pre-quantization if int8)."""
        if hasattr(self.mlp.c_fc, "int8_original_dtype"):
            return self.mlp.c_fc.int8_original_dtype
        return self.mlp.c_fc.weight.dtype

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Pre-norm attention + MLP with residual connections."""
        x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):
    """Stack of residual attention blocks.

    Chooses :class:`CustomResidualAttentionBlock` when any of the custom
    attention options are requested (or ``block_type="custom"``), otherwise the
    default :class:`ResidualAttentionBlock`.

    Args:
        width: embedding dimension.
        layers: number of blocks.
        heads: attention heads per block.
        mlp_ratio: MLP hidden width multiple.
        ls_init_value: LayerScale init value; ``None`` disables LayerScale.
        act_layer / norm_layer: classes for activation / normalization.
        batch_first: if False, inputs are (seq, batch, dim) and are transposed
            internally.
        block_type: "custom", "default", or ``None`` to infer from the flags.
        qk_norm, scaled_cosine_attn, scale_heads, scale_attn_inner, scale_attn,
        scale_fc: custom-attention options (see CustomResidualAttentionBlock).
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = LayerNorm,
        batch_first: bool = True,
        block_type: Optional[str] = None,
        qk_norm: bool = False,
        scaled_cosine_attn: bool = False,
        scale_heads: bool = False,
        scale_attn_inner: bool = False,
        scale_attn: bool = False,
        scale_fc: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.batch_first = batch_first
        self.grad_checkpointing = False

        if block_type is None:
            # Any custom-attention option forces the custom block.
            if any(
                [
                    qk_norm,
                    scaled_cosine_attn,
                    scale_heads,
                    scale_attn_inner,
                    scale_attn,
                    scale_fc,
                ]
            ):
                block_type = "custom"
            else:
                block_type = "default"

        if block_type == "custom":
            self.resblocks = nn.ModuleList(
                [
                    CustomResidualAttentionBlock(
                        width,
                        heads,
                        mlp_ratio,
                        ls_init_value=ls_init_value,
                        act_layer=act_layer,
                        norm_layer=norm_layer,
                        qk_norm=qk_norm,
                        scale_cosine_attn=scaled_cosine_attn,
                        scale_heads=scale_heads,
                        scale_attn_inner=scale_attn_inner,
                        scale_attn=scale_attn,
                        scale_fc=scale_fc,
                        batch_first=batch_first,
                    )
                    for _ in range(layers)
                ]
            )
        else:
            self.resblocks = nn.ModuleList(
                [
                    ResidualAttentionBlock(
                        width,
                        heads,
                        mlp_ratio,
                        ls_init_value=ls_init_value,
                        act_layer=act_layer,
                        norm_layer=norm_layer,
                        batch_first=batch_first,
                    )
                    for _ in range(layers)
                ]
            )

    def get_cast_dtype(self) -> torch.dtype:
        """Dtype of the first block's weights, used to cast embeddings."""
        return self.resblocks[0].get_weight_dtype()

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Run all blocks; handles (seq, batch, dim) layout if not batch_first."""
        if not self.batch_first:
            x = x.transpose(0, 1).contiguous()
        for r in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # BUGFIX: pass attn_mask as a keyword so both block types work.
                # The old positional form checkpoint(r, x, None, None, attn_mask)
                # crashed CustomResidualAttentionBlock, whose forward only takes
                # (x, attn_mask). Keyword args require use_reentrant=False.
                x = checkpoint(r, x, attn_mask=attn_mask, use_reentrant=False)
            else:
                x = r(x, attn_mask=attn_mask)
        if not self.batch_first:
            x = x.transpose(0, 1)
        return x


def _expand_token(token: torch.Tensor, batch_size: int):
    """Broadcast a single (dim,) token to shape (batch_size, 1, dim)."""
    return token.view(1, 1, -1).expand(batch_size, -1, -1)


def text_global_pool(
    x: torch.Tensor,
    text: Optional[torch.Tensor] = None,
    pool_type: str = "argmax",
    eos_token_id: Optional[int] = None,
) -> torch.Tensor:
    """Pool per-token features (B, L, D) into per-sequence features (B, D).

    Args:
        x: token features of shape (batch, seq_len, dim).
        text: token ids (batch, seq_len); required for "argmax" and "eos".
        pool_type: "first" | "last" | "argmax" (token with the max id, the
            EOT token in CLIP's vocab) | "eos" (first occurrence of
            ``eos_token_id``) | anything else returns ``x`` unpooled.
        eos_token_id: id to search for when ``pool_type == "eos"``.

    Returns:
        Pooled features (batch, dim), or ``x`` unchanged for unknown pool types.
    """
    if pool_type == "first":
        pooled = x[:, 0]
    elif pool_type == "last":
        pooled = x[:, -1]
    elif pool_type == "argmax":
        assert text is not None
        pooled = x[torch.arange(x.shape[0], device=x.device), text.argmax(dim=-1)]
    elif pool_type == "eos":
        assert text is not None
        assert eos_token_id is not None
        # argmax over the 0/1 match mask gives the FIRST eos position.
        # NOTE: if a row contains no eos token, argmax yields index 0.
        idx = (text == eos_token_id).int().argmax(dim=-1)
        pooled = x[torch.arange(x.shape[0], device=x.device), idx]
    else:
        pooled = x
    return pooled


class TextTransformer(nn.Module):
    """CLIP-style text encoder: embedding + transformer + pooling + projection.

    Args:
        context_length: maximum number of input tokens.
        vocab_size: token vocabulary size.
        width / heads / layers / mlp_ratio / ls_init_value: transformer config.
        output_dim: projection output size; ``None``/0 disables projection.
        embed_cls: append a learned CLS token at the END of the sequence
            (extends the positional table by one) and pool from it.
        no_causal_mask: disable the causal attention mask.
        use_pad_mask: mask out pad tokens (only effective with no_causal_mask).
        correct_cls_mask: place the CLS entry at the end of the pad-mask
            (matching the CLS position) instead of the legacy front placement.
        pad_id: token id treated as padding.
        eos_id: token id used by the "eos" pool type.
        pool_type: "first" | "last" | "argmax" | "eos" | "none".
        proj_type: "linear" or "none".
        proj_bias: use nn.Linear (with bias) instead of a bare weight matrix.
        act_layer / norm_layer: transformer classes.
        output_tokens: also return the per-token features from forward().
        block_type and the remaining flags: forwarded to :class:`Transformer`.
    """

    output_tokens: torch.jit.Final[bool]

    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        output_dim: Optional[int] = 512,
        embed_cls: bool = False,
        no_causal_mask: bool = False,
        use_pad_mask: bool = False,
        correct_cls_mask: bool = False,
        pad_id: int = 0,
        eos_id: int = 2,
        pool_type: str = "argmax",
        proj_type: str = "linear",
        proj_bias: bool = False,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = LayerNorm,
        output_tokens: bool = False,
        block_type: Optional[str] = None,
        qk_norm: bool = False,
        scaled_cosine_attn: bool = False,
        scale_heads: bool = False,
        scale_attn_inner: bool = False,
        scale_attn: bool = False,
        scale_fc: bool = False,
    ):
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "eos", "none")
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pad_id = pad_id
        self.eos_id = eos_id
        self.pool_type = pool_type
        # Pad masking is only meaningful without the causal mask.
        self.use_pad_mask = use_pad_mask and no_causal_mask
        self.correct_cls_mask = correct_cls_mask

        self.token_embedding = nn.Embedding(vocab_size, width)
        if embed_cls:
            self.cls_emb = nn.Parameter(torch.empty(width))
            self.num_pos += 1  # one extra position for the appended CLS token
        else:
            self.cls_emb = None
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            block_type=block_type,
            qk_norm=qk_norm,
            scaled_cosine_attn=scaled_cosine_attn,
            scale_heads=scale_heads,
            scale_attn_inner=scale_attn_inner,
            scale_attn=scale_attn,
            scale_fc=scale_fc,
        )
        self.ln_final = norm_layer(width)

        if no_causal_mask:
            self.attn_mask = None
        else:
            self.register_buffer(
                "attn_mask", self.build_causal_mask(), persistent=False
            )

        if proj_type == "none" or not output_dim:
            self.text_projection = None
        else:
            if proj_bias:
                self.text_projection = nn.Linear(width, output_dim)
            else:
                self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        self.init_parameters()

    def init_parameters(self):
        """Initialize all weights following the CLIP init scheme."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        if self.cls_emb is not None:
            nn.init.normal_(self.cls_emb, std=0.01)

        # Depth-scaled init for residual projections (as in GPT-2 / CLIP).
        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers) ** -0.5
        )
        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                nn.init.normal_(
                    self.text_projection.weight, std=self.transformer.width**-0.5
                )
                if self.text_projection.bias is not None:
                    nn.init.zeros_(self.text_projection.bias)
            else:
                nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)

    def build_causal_mask(self):
        """Additive causal mask: -inf above the diagonal, 0 elsewhere."""
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower triangle incl. diagonal
        return mask

    def _build_additive_mask(self, text, seq_len, dtype):
        """Build an additive pad mask of shape (B * heads, seq_len, seq_len).

        Positions whose KEY token equals ``pad_id`` get -inf; the CLS entry is
        always valid. With ``correct_cls_mask`` the CLS entry is appended at
        the end (matching the CLS token position); otherwise it is prepended
        (legacy behavior).
        """
        valid = text != self.pad_id
        if self.cls_emb is not None:
            cls_valid = valid.new_ones(valid.size(0), 1)
            valid = torch.cat(
                [valid, cls_valid] if self.correct_cls_mask else [cls_valid, valid], 1
            )
        # Broadcast key validity across all query positions.
        key_mask = valid.unsqueeze(1).expand(-1, seq_len, -1)
        additive = torch.zeros_like(key_mask, dtype=dtype)
        additive.masked_fill_(~key_mask, float("-inf"))
        # Expand to one mask per attention head.
        additive = additive.repeat_interleave(self.heads, 0)
        return additive

    def _embeds(self, text):
        """Token + positional embedding; returns (embeddings, attn_mask)."""
        cast_dtype = self.transformer.get_cast_dtype()
        B, seq_len = text.shape
        x = self.token_embedding(text).to(cast_dtype)
        if self.cls_emb is not None:
            # CLS token is appended at the END of the sequence.
            x = torch.cat([x, _expand_token(self.cls_emb, x.size(0))], 1)
            seq_len += 1
        attn_mask = self.attn_mask
        if self.use_pad_mask or self.cls_emb is not None:
            add_mask = self._build_additive_mask(text, seq_len, x.dtype)
            if attn_mask is not None:
                # Combine causal (1, L, L) with per-sample pad mask (B*H, L, L).
                attn_mask = attn_mask[:seq_len, :seq_len].unsqueeze(0) + add_mask
            else:
                attn_mask = add_mask
        x = x + self.positional_embedding[:seq_len].to(cast_dtype)
        return x, attn_mask

    def forward(self, text):
        """Encode token ids (B, L) into pooled features (B, output_dim).

        Returns (pooled, tokens) when ``output_tokens`` is set.
        """
        x, attn_mask = self._embeds(text)
        x = self.transformer(x, attn_mask=attn_mask)
        if self.cls_emb is not None:
            # CLS sits at the last position; per-token features exclude it.
            pooled = text_global_pool(x, pool_type="last")
            pooled = self.ln_final(pooled)
            tokens = x[:, :-1]
        else:
            x = self.ln_final(x)
            pooled = text_global_pool(
                x,
                text,
                pool_type=self.pool_type,
                eos_token_id=getattr(self, "eos_id", None),
            )
            tokens = x
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                pooled = self.text_projection(pooled)
            else:
                pooled = pooled @ self.text_projection
        if self.output_tokens:
            return pooled, tokens
        return pooled