# Provenance (from model hub page, non-code): ValentineKRAFTON — initial commit acd771b (verified)
# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from functools import partial
from .timm_model import TimmModel
from .transformer import (
LayerNormFp32,
LayerNorm,
QuickGELU,
TextTransformer,
text_global_pool,
)
from .utils import to_2tuple
@dataclass
class CLIPVisionCfg:
    """Configuration for the CLIP vision tower.

    In this file only the timm-backed path is used (see
    ``_build_vision_tower``), so the ``timm_*`` fields select and configure
    the backbone; the remaining fields mirror the upstream OpenCLIP config
    and several are not consumed by the timm path.
    """

    layers: Union[Tuple[int, int, int, int], int] = 12  # depth, or per-stage tuple (ResNet-style)
    width: int = 768  # transformer embedding width
    head_width: int = 64  # per-attention-head width
    mlp_ratio: float = 4.0  # MLP hidden dim = width * mlp_ratio
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224  # int or (H, W)
    ls_init_value: Optional[float] = None  # LayerScale init; None disables LayerScale
    patch_dropout: float = 0.0  # fraction of patches dropped at train time; 0 disables
    attentional_pool: bool = False
    attn_pooler_queries: int = 256
    attn_pooler_heads: int = 8
    no_ln_pre: bool = False  # skip the pre-transformer LayerNorm
    pos_embed_type: str = "learnable"
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling instead of before
    pool_type: str = "tok"  # how to pool tokens into a single embedding
    output_tokens: bool = False  # also return per-patch tokens, not just pooled features
    act_kwargs: Optional[dict] = None  # extra kwargs for the activation layer
    norm_kwargs: Optional[dict] = None  # extra kwargs for the norm layer
    block_type: Optional[str] = None
    qk_norm: bool = False  # normalize Q/K inside attention
    scaled_cosine_attn: bool = False
    scale_heads: bool = False
    scale_attn_inner: bool = False
    scale_attn: bool = False
    scale_fc: bool = False
    # --- timm backbone options (the only path supported by _build_vision_tower) ---
    timm_model_name: Optional[str] = None  # a valid timm model name; required here
    timm_model_pretrained: bool = False  # load timm pretrained weights
    timm_pool: str = "avg"  # feature pooling: '' / 'avg' / 'tok' etc. (timm semantics)
    timm_proj: str = "linear"  # projection from backbone dim to embed_dim
    timm_proj_bias: bool = False
    timm_drop: float = 0.0  # head dropout
    timm_drop_path: Optional[float] = None  # stochastic depth rate; None = timm default
    timm_use_rope: bool = False  # rotary position embeddings
    timm_rope_keep_ape: bool = False  # keep absolute pos-embed alongside RoPE
    timm_dynamic_img_size: bool = False  # allow variable input resolution
    timm_norm_pre: bool = False
@dataclass
class CLIPTextCfg:
    """Configuration for the CLIP text tower (see ``_build_text_tower``).

    Mirrors the upstream OpenCLIP text config; the ``hf_*`` fields configure
    an optional HuggingFace-based tokenizer/encoder and are not consumed by
    the TextTransformer path in this file.
    """

    context_length: int = 77  # maximum token sequence length
    vocab_size: int = 49408
    hf_tokenizer_name: Optional[str] = None
    tokenizer_mode: Optional[str] = None
    tokenizer_kwargs: Optional[dict] = None
    width: int = 512  # transformer embedding width
    heads: int = 8
    layers: int = 12
    mlp_ratio: float = 4.0  # MLP hidden dim = width * mlp_ratio
    ls_init_value: Optional[float] = None  # LayerScale init; None disables LayerScale
    embed_cls: bool = False  # prepend a learned CLS token
    pad_id: int = 0
    eos_id: int = 2
    no_causal_mask: bool = False  # disable the causal attention mask
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling instead of before
    pool_type: str = "argmax"  # pooling strategy for the sequence output
    proj_bias: bool = False
    proj_type: str = "linear"  # projection from width to embed_dim
    output_tokens: bool = False  # also return per-token features, not just pooled
    # Fix: defaults are None, so annotate as Optional[dict] (consistent with
    # CLIPVisionCfg; the bare `dict = None` annotation contradicted the default).
    act_kwargs: Optional[dict] = None  # extra kwargs for the activation layer
    norm_kwargs: Optional[dict] = None  # extra kwargs for the norm layer
    block_type: Optional[str] = None
    qk_norm: bool = False  # normalize Q/K inside attention
    scaled_cosine_attn: bool = False
    scale_heads: bool = False
    scale_attn_inner: bool = False
    scale_attn: bool = False
    scale_fc: bool = False
    hf_model_name: Optional[str] = None
    hf_model_pretrained: bool = True
    hf_proj_type: str = "mlp"
    hf_pooler_type: str = "mean_pooler"
def get_cast_dtype(precision: str):
    """Translate a precision string into a torch dtype.

    Returns ``torch.bfloat16`` for "bf16", ``torch.float16`` for "fp16",
    and ``None`` for every other value (meaning: no casting).
    """
    return {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }.get(precision)
def _build_vision_tower(
    embed_dim: int,
    vision_cfg: CLIPVisionCfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
):
    """Build the timm-backed visual encoder described by ``vision_cfg``.

    ``quick_gelu`` and ``cast_dtype`` are accepted for signature parity with
    ``_build_text_tower`` but are not consumed on the timm path.
    """
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # This fork supports only timm backbones for the vision side.
    if not vision_cfg.timm_model_name:
        raise ValueError(
            "Only TimmModel-based vision towers are supported in raon-vision-encoder. "
            "Please set timm_model_name in vision_cfg."
        )

    # A patch_dropout of 0 means "disabled", which TimmModel expects as None.
    patch_drop = vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None

    return TimmModel(
        vision_cfg.timm_model_name,
        pretrained=vision_cfg.timm_model_pretrained,
        pool=vision_cfg.timm_pool,
        proj=vision_cfg.timm_proj,
        proj_bias=vision_cfg.timm_proj_bias,
        drop=vision_cfg.timm_drop,
        drop_path=vision_cfg.timm_drop_path,
        patch_drop=patch_drop,
        init_values=vision_cfg.ls_init_value,
        qk_norm=vision_cfg.qk_norm,
        use_rope=vision_cfg.timm_use_rope,
        rope_keep_ape=vision_cfg.timm_rope_keep_ape,
        dynamic_img_size=vision_cfg.timm_dynamic_img_size,
        norm_pre=vision_cfg.timm_norm_pre,
        embed_dim=embed_dim,
        image_size=vision_cfg.image_size,
        output_tokens=vision_cfg.output_tokens,
    )
def _build_text_tower(
    embed_dim: int,
    text_cfg: CLIPTextCfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
):
    """Build the TextTransformer described by ``text_cfg``.

    ``quick_gelu`` swaps in the original CLIP QuickGELU activation;
    ``cast_dtype`` in (fp16, bf16) selects an fp32-computing LayerNorm.
    """
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    # Select the activation, then specialize it with any configured kwargs.
    act_layer = QuickGELU if quick_gelu else nn.GELU
    if text_cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **text_cfg.act_kwargs)

    # Under half-precision casting, keep LayerNorm math in fp32.
    if cast_dtype in (torch.float16, torch.bfloat16):
        norm_layer = LayerNormFp32
    else:
        norm_layer = LayerNorm
    if text_cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)

    return TextTransformer(
        context_length=text_cfg.context_length,
        vocab_size=text_cfg.vocab_size,
        width=text_cfg.width,
        heads=text_cfg.heads,
        layers=text_cfg.layers,
        mlp_ratio=text_cfg.mlp_ratio,
        ls_init_value=text_cfg.ls_init_value,
        output_dim=embed_dim,
        embed_cls=text_cfg.embed_cls,
        no_causal_mask=text_cfg.no_causal_mask,
        pad_id=text_cfg.pad_id,
        eos_id=text_cfg.eos_id,
        pool_type=text_cfg.pool_type,
        proj_type=text_cfg.proj_type,
        proj_bias=text_cfg.proj_bias,
        output_tokens=text_cfg.output_tokens,
        act_layer=act_layer,
        norm_layer=norm_layer,
        block_type=text_cfg.block_type,
        qk_norm=text_cfg.qk_norm,
        scaled_cosine_attn=text_cfg.scaled_cosine_attn,
        scale_heads=text_cfg.scale_heads,
        scale_attn_inner=text_cfg.scale_attn_inner,
        scale_attn=text_cfg.scale_attn,
        scale_fc=text_cfg.scale_fc,
    )
class CustomTextCLIP(nn.Module):
    """CLIP model with independently configured vision and text towers.

    The vision tower is timm-backed (see ``_build_vision_tower``); the text
    tower is a ``TextTransformer``. A learned temperature (``logit_scale``,
    stored pre-exp) and an optional additive ``logit_bias`` scale the
    image-text similarity logits.
    """

    output_dict: torch.jit.Final[bool]

    def __init__(
        self,
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        init_logit_scale: float = np.log(1 / 0.07),
        init_logit_bias: Optional[float] = None,
        nonscalar_logit_scale: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict

        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size

        # Shape [1] (instead of scalar) when a non-scalar parameter is required.
        scale_shape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(scale_shape) * init_logit_scale)
        if init_logit_bias is None:
            self.logit_bias = None
        else:
            self.logit_bias = nn.Parameter(torch.ones(scale_shape) * init_logit_bias)

    def encode_image(
        self, pixel_values, normalize: bool = False, pixel_attention_mask=None, spatial_shapes=None
    ):
        """Run the vision tower; optionally L2-normalize the features."""
        extras = {}
        if pixel_attention_mask is not None:
            extras["patch_valid_mask"] = pixel_attention_mask
        if spatial_shapes is not None:
            extras["spatial_shapes"] = spatial_shapes
        # Forward the optional kwargs only when present, so towers that do not
        # accept them keep working.
        if extras:
            features = self.visual(pixel_values, **extras)
        else:
            features = self.visual(pixel_values)
        if normalize:
            return F.normalize(features, dim=-1)
        return features

    def encode_text(self, input_ids, normalize: bool = False):
        """Run the text tower; optionally L2-normalize the features."""
        features = self.text(input_ids)
        if normalize:
            return F.normalize(features, dim=-1)
        return features

    def get_logits(self, image, text):
        """Return (image->text, text->image) similarity logits."""
        image_features = self.encode_image(pixel_values=image, normalize=True)
        text_features = self.encode_text(input_ids=text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits = image_logits + self.logit_bias
        return image_logits, image_logits.T

    def forward(
        self, image=None, text=None, patch_valid_mask=None, spatial_shapes=None
    ):
        """Encode whichever of image/text is provided.

        Returns a dict (when ``output_dict``) or a tuple of normalized
        image features, text features, exp'd logit scale, and — when
        configured — the logit bias. Missing modalities yield ``None``.
        """
        image_features = None
        if image is not None:
            image_features = self.encode_image(
                pixel_values=image,
                normalize=True,
                pixel_attention_mask=patch_valid_mask,
                spatial_shapes=spatial_shapes,
            )
        text_features = None
        if text is not None:
            text_features = self.encode_text(input_ids=text, normalize=True)

        logit_scale = self.logit_scale.exp()
        if self.output_dict:
            out = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": logit_scale,
            }
            if self.logit_bias is not None:
                out["logit_bias"] = self.logit_bias
            return out

        if self.logit_bias is not None:
            return image_features, text_features, logit_scale, self.logit_bias
        return image_features, text_features, logit_scale