Initial upload of simplicityprevails from local project
- README.md +83 -0
- core/__init__.py +0 -0
- core/vision_encoder/__init__.py +0 -0
- core/vision_encoder/bpe_simple_vocab_16e6.txt.gz +3 -0
- core/vision_encoder/config.py +261 -0
- core/vision_encoder/pe.py +761 -0
- core/vision_encoder/rope.py +347 -0
- core/vision_encoder/tokenizer.py +342 -0
- core/vision_encoder/transforms.py +31 -0
- models.py +331 -0
- test_vfm_baselines.py +153 -0
- weights/dinov2lin0.pth +3 -0
- weights/dinov3lin0.pth +3 -0
- weights/metaclip2lin0.pth +3 -0
- weights/metacliplin0.pth +3 -0
- weights/pelin0.pth +3 -0
- weights/siglip2lin0.pth +3 -0
- weights/sigliplin0.pth +3 -0
README.md
ADDED

@@ -0,0 +1,83 @@
# VFM Baselines Release

This directory contains the 7 vision foundation model baselines used in the paper:

- `MetaCLIP-Linear`
- `MetaCLIP2-Linear`
- `SigLIP-Linear`
- `SigLIP2-Linear`
- `PE-CLIP-Linear`
- `DINOv2-Linear`
- `DINOv3-Linear`

## Contents

- `models.py`: unified model-loading code for all 7 baselines
- `test_vfm_baselines.py`: unified evaluation script
- `weights/`: released checkpoints
- `core/vision_encoder/`: vendored PE vision encoder code required by `PE-CLIP-Linear`

## Model Names

The unified loader and test script accept these names:

- `metacliplin`
- `metaclip2lin`
- `sigliplin`
- `siglip2lin`
- `pelin`
- `dinov2lin`
- `dinov3lin`

The paper names such as `MetaCLIP-Linear` and `DINOv3-Linear` are also accepted.
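The two naming schemes map one-to-one. As a reference, a minimal sketch of how such aliasing could be implemented; the table below is inferred from the two lists in this README, not copied from `models.py`:

```python
# Hypothetical alias table: paper-style names -> loader names.
# Inferred from the README lists; the real table in models.py may differ.
PAPER_TO_LOADER = {
    "MetaCLIP-Linear": "metacliplin",
    "MetaCLIP2-Linear": "metaclip2lin",
    "SigLIP-Linear": "sigliplin",
    "SigLIP2-Linear": "siglip2lin",
    "PE-CLIP-Linear": "pelin",
    "DINOv2-Linear": "dinov2lin",
    "DINOv3-Linear": "dinov3lin",
}


def normalize_model_name(name: str) -> str:
    """Accept either a loader name or a paper name; return the loader name."""
    if name in PAPER_TO_LOADER.values():
        return name
    if name in PAPER_TO_LOADER:
        return PAPER_TO_LOADER[name]
    raise ValueError(f"Unknown model name: {name}")
```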

## Usage

Evaluate a single model:

```bash
python test_vfm_baselines.py \
    --model sigliplin \
    --real-dir /path/to/0_real \
    --fake-dir /path/to/1_fake \
    --max-samples 100
```

Evaluate all 7 models:

```bash
python test_vfm_baselines.py \
    --model all \
    --real-dir /path/to/0_real \
    --fake-dir /path/to/1_fake \
    --max-samples 100
```

Optional arguments:

- `--checkpoint`: override the default checkpoint for single-model evaluation
- `--batch-size`: batch size for evaluation
- `--num-workers`: dataloader workers
- `--device`: explicit device such as `cuda:0` or `cpu`
- `--save-json`: save results to a JSON file
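As a sanity check for reported numbers, ROC-AUC over per-image "fake" scores can be computed in a few lines of plain Python. This is an illustrative sketch only, not the script's actual implementation (`test_vfm_baselines.py` may use scikit-learn instead):

```python
def roc_auc(labels, scores):
    """ROC-AUC via the Mann-Whitney U statistic.

    labels: 1 for fake, 0 for real; scores: higher means more likely fake.
    Equivalent to the probability a random fake outranks a random real,
    with ties counted as half.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```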

## Dependencies

The release code expects these Python packages:

- `torch`
- `torchvision`
- `transformers`
- `scikit-learn`
- `Pillow`
- `timm`
- `einops`
- `ftfy`
- `regex`
- `huggingface_hub`
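A few of these pip names differ from their import names (`Pillow` imports as `PIL`, `scikit-learn` as `sklearn`). A small stdlib-only check can confirm everything is importable before running the evaluation; this helper is an optional convenience, not part of the release code:

```python
from importlib.util import find_spec

# pip name -> import name for the packages listed above
IMPORT_NAMES = {
    "torch": "torch",
    "torchvision": "torchvision",
    "transformers": "transformers",
    "scikit-learn": "sklearn",
    "Pillow": "PIL",
    "timm": "timm",
    "einops": "einops",
    "ftfy": "ftfy",
    "regex": "regex",
    "huggingface_hub": "huggingface_hub",
}


def missing_packages(import_names=IMPORT_NAMES) -> list:
    """Return the pip names whose import module cannot be found."""
    return [pip for pip, mod in import_names.items() if find_spec(mod) is None]


if __name__ == "__main__":
    print(missing_packages() or "all dependencies available")
```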

## Notes

- The CLIP-family and DINO-family baselines instantiate the backbone from Hugging Face model configs and then load the released checkpoint.
- `PE-CLIP-Linear` uses the vendored `core/vision_encoder` code in this directory.
- The checkpoints in `weights/` are arranged locally for packaging convenience. For public release, they can be uploaded under the same filenames.
core/__init__.py
ADDED
File without changes

core/vision_encoder/__init__.py
ADDED
File without changes
core/vision_encoder/bpe_simple_vocab_16e6.txt.gz
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
core/vision_encoder/config.py
ADDED

@@ -0,0 +1,261 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
Include all available vision encoder configurations.
"""

from dataclasses import dataclass, replace

from functools import partial
from typing import Callable, Optional, Sequence, Tuple, List

from huggingface_hub import hf_hub_download


def fetch_pe_checkpoint(name: str, path: Optional[str] = None):
    path = path or f"hf://facebook/{name}:{name}.pt"

    if path.startswith("hf://"):
        # Load from huggingface
        path = path[len("hf://"):]
        repo, file = path.split(":")

        return hf_hub_download(repo_id=repo, filename=file)
    else:
        return path

@dataclass
class PEConfig:
    """ Vision Tower Config. """
    patch_size: int
    width: int
    layers: int
    heads: int
    mlp_ratio: float
    output_dim: Optional[int]

    ls_init_value: Optional[float] = None
    drop_path: float = 0.0

    image_size: int = 224
    use_abs_posemb: bool = True
    use_cls_token: bool = False
    use_rope2d: bool = True

    pool_type: str = "attn"
    attn_pooler_heads: int = 8

    use_ln_pre: bool = True
    use_ln_post: bool = True

@dataclass
class PETextConfig:
    """ Text Tower Config. """
    context_length: int
    width: int
    heads: int
    layers: int

    output_dim: int

    mlp_ratio: float = 4.0
    vocab_size: int = 49408


PE_VISION_CONFIG = {}
PE_TEXT_CONFIG = {}


#########################################
#                PE CORE                #
#########################################

PE_VISION_CONFIG["PE-Core-G14-448"] = PEConfig(
    image_size=448,
    patch_size=14,
    width=1536,
    layers=50,
    heads=16,
    mlp_ratio=8960 / 1536,
    pool_type="attn",
    output_dim=1280,
    use_cls_token=False,
)
PE_TEXT_CONFIG["PE-Core-G14-448"] = PETextConfig(
    context_length=72,
    width=1280,
    heads=20,
    layers=24,
    output_dim=1280,
)


PE_VISION_CONFIG["PE-Core-L14-336"] = PEConfig(
    image_size=336,
    patch_size=14,
    width=1024,
    layers=24,
    heads=16,
    mlp_ratio=4.0,
    pool_type="attn",
    output_dim=1024,
    use_cls_token=True,
)
PE_TEXT_CONFIG["PE-Core-L14-336"] = PETextConfig(
    context_length=32,
    width=1024,
    heads=16,
    layers=24,
    output_dim=1024,
)


PE_VISION_CONFIG["PE-Core-B16-224"] = PEConfig(
    image_size=224,
    patch_size=16,
    width=768,
    layers=12,
    heads=12,
    mlp_ratio=4.0,
    pool_type="attn",
    output_dim=1024,
    use_cls_token=True,
)
PE_TEXT_CONFIG["PE-Core-B16-224"] = PE_TEXT_CONFIG["PE-Core-L14-336"]


PE_VISION_CONFIG["PE-Core-S16-384"] = PEConfig(
    image_size=384,
    patch_size=16,
    width=384,
    layers=12,
    heads=6,
    mlp_ratio=4.0,
    pool_type="attn",
    output_dim=512,
    use_cls_token=True,
)
PE_TEXT_CONFIG["PE-Core-S16-384"] = PETextConfig(
    context_length=32,
    width=512,
    heads=8,
    layers=12,
    output_dim=512,
)


PE_VISION_CONFIG["PE-Core-T16-384"] = PEConfig(
    image_size=384,
    patch_size=16,
    width=192,
    layers=12,
    heads=3,
    mlp_ratio=4.0,
    pool_type="attn",
    output_dim=512,
    use_cls_token=True,
)
PE_TEXT_CONFIG["PE-Core-T16-384"] = PE_TEXT_CONFIG["PE-Core-S16-384"]


#########################################
#                PE Lang                #
#########################################

PE_VISION_CONFIG["PE-Lang-G14-448"] = replace(
    PE_VISION_CONFIG["PE-Core-G14-448"],
    image_size=448,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
    ls_init_value=0.1,
    layers=47,
)

PE_VISION_CONFIG["PE-Lang-L14-448"] = replace(
    PE_VISION_CONFIG["PE-Core-L14-336"],
    image_size=448,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
    ls_init_value=0.1,
    layers=23,
)


# Stage 2 checkpoints for PLM-8B and PLM-3B respectively. Pretrained with tiling.
# Use these checkpoints if you're building a model that uses tiling downstream!
PE_VISION_CONFIG["PE-Lang-G14-448-Tiling"] = PE_VISION_CONFIG["PE-Lang-G14-448"]
PE_VISION_CONFIG["PE-Lang-L14-448-Tiling"] = PE_VISION_CONFIG["PE-Lang-L14-448"]


#########################################
#              PE Spatial               #
#########################################

PE_VISION_CONFIG["PE-Spatial-G14-448"] = replace(
    PE_VISION_CONFIG["PE-Core-G14-448"],
    image_size=448,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
    ls_init_value=0.1,
)

# No layerscale on the smaller spatial models
PE_VISION_CONFIG["PE-Spatial-L14-448"] = replace(
    PE_VISION_CONFIG["PE-Core-L14-336"],
    image_size=448,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
)


PE_VISION_CONFIG["PE-Spatial-B16-512"] = replace(
    PE_VISION_CONFIG["PE-Core-B16-224"],
    image_size=512,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
)


PE_VISION_CONFIG["PE-Spatial-S16-512"] = replace(
    PE_VISION_CONFIG["PE-Core-S16-384"],
    image_size=512,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
)


PE_VISION_CONFIG["PE-Spatial-T16-512"] = replace(
    PE_VISION_CONFIG["PE-Core-T16-384"],
    image_size=512,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
)
core/vision_encoder/pe.py
ADDED

@@ -0,0 +1,761 @@
import copy
import math
import random
from collections import OrderedDict
from dataclasses import asdict
from functools import partial
from logging import getLogger
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Literal

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from timm.layers import DropPath
from torch import nn
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter
from torch.utils.checkpoint import checkpoint

from core.vision_encoder.rope import Rope2D
from core.vision_encoder.config import PEConfig, PETextConfig, PE_VISION_CONFIG, PE_TEXT_CONFIG, fetch_pe_checkpoint


logger = getLogger()


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.dim = dim
        self.init_values = init_values

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

    def init_tensors(self):
        self.gamma = nn.Parameter(self.init_values * torch.ones(self.dim))


class AttentionPooling(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_probe: int = 1,
        mlp_ratio: int = 4,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads

        assert (
            self.embed_dim % num_heads == 0
        ), "embed_dim must be divisible by num_heads"

        self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim))
        self.attn = nn.MultiheadAttention(
            self.embed_dim, self.num_heads, batch_first=True
        )

        self.layernorm = norm_layer(embed_dim)
        self.mlp_width = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(self.embed_dim, self.mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(self.mlp_width, self.embed_dim)),
                ]
            )
        )

    def forward(self, x: torch.Tensor):
        batch, _, _ = x.shape

        q = self.probe.repeat((batch, 1, 1)).to(x.dtype)
        x = self.attn(q, x, x, need_weights=False)[0]
        x = x + self.mlp(self.layernorm(x))

        return x

class SelfAttention(nn.Module):
    r"""
    Implements sequence packed attention and RoPE
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        rope: Optional[nn.Module] = None,
    ):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        # To make this compatible with nn.MultiheadAttention
        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
        self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

        self.rope = rope
        self.scale = self.head_dim ** (-0.5)

    def init_tensors(self):
        xavier_uniform_(self.in_proj_weight)
        constant_(self.in_proj_bias, 0.0)
        constant_(self.out_proj.bias, 0.0)

    def forward(self, x, attn_mask=None):
        batch, seq, embed_dim = x.shape
        proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)

        # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
        proj = (
            proj.unflatten(-1, (3, embed_dim))
            .unsqueeze(0)
            .transpose(0, -2)
            .squeeze(-2)
            .contiguous()
        )
        q, k, v = proj[0], proj[1], proj[2]

        # Split heads: (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
        k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
        v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)

        if self.rope:
            q, k = self.rope(q, k)

        attn = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False, scale=self.scale
        )
        attn = rearrange(attn, "b h s d -> b s (h d)")

        return F.linear(attn, self.out_proj.weight, self.out_proj.bias)

class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        drop_path: float = 0.0,
        rope: Optional[nn.Module] = None,
    ):
        super().__init__()

        if rope:
            self.attn = SelfAttention(d_model, n_head, rope=rope)
        else:
            self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

        self.ls_1 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )
        self.ls_2 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)

        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                ]
            )
        )

    def _call_attn(
        self,
        q_x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ):

        if attn_mask is not None:
            # Leave boolean masks as is
            if not attn_mask.dtype == torch.bool:
                attn_mask = attn_mask.to(q_x.dtype)

        if isinstance(self.attn, SelfAttention):
            return self.attn(q_x, attn_mask=attn_mask)
        else:
            return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0]

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        x = x + self.drop_path1(
            self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask))
        )
        x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))
        return x

class Transformer(nn.Module):
    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        drop_path: float = 0.0,
        rope: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False

        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                    rope=rope,
                )
                for _ in range(layers)
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def truncate(self, layer_idx: int):
        """ Delete layers so the last layer is the given layer index. """
        self.layers = ((self.layers + layer_idx) % self.layers) + 1
        self.resblocks = nn.ModuleList(self.resblocks[: self.layers])

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        layer_idx: int = -1,
    ):
        stop_idx = (self.layers + layer_idx) % self.layers

        for i, r in enumerate(self.resblocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)

            if i == stop_idx:
                break

        return x

class VisionTransformer(nn.Module):
    def __init__(
        self,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
        use_ln_pre: bool = True,
        use_ln_post: bool = True,
        ls_init_value: float = None,
        drop_path: float = 0.0,
        image_size: int = 448,  # Pretrain image size only; you can pass in any image size
        use_abs_posemb: bool = True,
        use_rope2d: bool = True,
        use_cls_token: bool = False,
        output_dim: Optional[int] = 1280,
        attn_pooler_heads: int = 8,
        pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
    ):
        super().__init__()
        assert pool_type in ("attn", "tok", "avg", "none")
        self.pool_type = pool_type
        self.patch_size = patch_size

        self.output_dim = output_dim or width
        self.proj_dim = output_dim
        self.heads = heads
        self.width = width
        self.layers = layers

        self.use_abs_posemb = use_abs_posemb
        self.use_cls_token = use_cls_token
        self.use_rope2d = use_rope2d
        self.image_size = image_size

        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        self.rope = (
            Rope2D(
                dim=width // heads,
                use_cls_token=self.use_cls_token,
            )
            if self.use_rope2d
            else None
        )

        self.ln_pre = norm_layer(width) if use_ln_pre else nn.Identity()
        self.ln_post = norm_layer(self.width) if use_ln_post else nn.Identity()

        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            drop_path=drop_path,
            rope=self.rope,
        )

        if pool_type == "attn":
            self.attn_pool = AttentionPooling(
                embed_dim=width,
                num_heads=attn_pooler_heads,
                act_layer=act_layer,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None

        self.init_tensors()

    def init_tensors(self):
        def init_submodule_tensors(module):
            for name, child in module.named_children():
                if hasattr(child, "init_tensors"):
                    logger.debug(f"Initializing tensors for submodule: {name}")
                    child.init_tensors()
                init_submodule_tensors(child)

        init_submodule_tensors(self)
        if self.rope is not None:
            self.rope.init_tensors()

        # class embeddings and positional embeddings
        init_scale = self.width**-0.5

        if self.use_cls_token:
            self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))

        if self.use_abs_posemb:
            self.posemb_grid_size = self.image_size // self.patch_size
            self.positional_embedding = nn.Parameter(
                init_scale
                * torch.randn(
                    int(self.use_cls_token) + self.posemb_grid_size**2, self.width
                )
            )

        if self.proj_dim is not None:
            self.proj = nn.Parameter(
                init_scale * torch.randn(self.width, self.proj_dim)
            )

    def load_ckpt(self, ckpt_path: str, verbose: bool = True):
        _sd = torch.load(ckpt_path, weights_only=True)
        if "state_dict" in _sd:
            _sd = _sd["state_dict"]
        elif "weights" in _sd:
            _sd = _sd["weights"]

        # for backwards compatibility
        _sd = {k.replace("module.", ""): v for k, v in _sd.items()}
        if any(k.startswith("visual.") for k in _sd):
            _sd = {k.replace("visual.", ""): v for k, v in _sd.items() if "visual" in k}

        m, u = self.load_state_dict(_sd, strict=False)

        if verbose or (m or u):
            logger.info(f"Missing keys for loading vision encoder: {m}")
            logger.info(f"Unexpected keys for loading vision encoder: {u}")
            print(f"Missing keys for loading vision encoder: {m}")
            print(f"Unexpected keys for loading vision encoder: {u}")

    def truncate(self, layer_idx: int):
        """ Delete layers so the last layer is the given layer index. """
        self.transformer.truncate(layer_idx)
        self.layers = self.transformer.layers

    @classmethod
    def from_config(
        cls,
        name: str,
        pretrained: bool = False,
        checkpoint_path: Optional[str] = None,
        **kwdargs,
    ):
        if name not in PE_VISION_CONFIG:
            raise RuntimeError(f"{name} not found in configs.")

        args = asdict(PE_VISION_CONFIG[name])
        args.update(kwdargs)

        model = cls(**args)
        if pretrained:
            model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))
|
| 451 |
+
|
| 452 |
+
return model
|
| 453 |
+
|
| 454 |
+
@classmethod
|
| 455 |
+
def available_configs(cls):
|
| 456 |
+
return list(PE_VISION_CONFIG.keys())
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
@torch.jit.ignore
|
| 460 |
+
def set_grad_checkpointing(self, enable=True):
|
| 461 |
+
self.transformer.set_grad_checkpointing(enable=enable)
|
| 462 |
+
|
| 463 |
+
def _sample_abs_posemb(self, grid_h: int, grid_w: int):
|
| 464 |
+
"""Interpolates the absolute position embedding if necessary."""
|
| 465 |
+
if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
|
| 466 |
+
return self.positional_embedding[None, ...]
|
| 467 |
+
|
| 468 |
+
pos_embed = self.positional_embedding
|
| 469 |
+
if self.use_cls_token:
|
| 470 |
+
cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
|
| 471 |
+
|
| 472 |
+
pos_embed = (
|
| 473 |
+
pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
|
| 474 |
+
.permute(0, 3, 1, 2)
|
| 475 |
+
.contiguous()
|
| 476 |
+
)
|
| 477 |
+
pos_embed = F.interpolate(
|
| 478 |
+
pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
|
| 479 |
+
)
|
| 480 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()
|
| 481 |
+
|
| 482 |
+
if self.use_cls_token:
|
| 483 |
+
pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
|
| 484 |
+
|
| 485 |
+
return pos_embed[None, ...]
|
| 486 |
+
|
| 487 |
+
def _pool(self, x: torch.Tensor):
|
| 488 |
+
if self.pool_type == "tok":
|
| 489 |
+
return x[:, 0]
|
| 490 |
+
elif self.pool_type == "avg":
|
| 491 |
+
return x.mean(dim=1)
|
| 492 |
+
elif self.pool_type == "attn":
|
| 493 |
+
return self.attn_pool(x).squeeze(1)
|
| 494 |
+
elif self.pool_type == "none":
|
| 495 |
+
return x
|
| 496 |
+
else:
|
| 497 |
+
raise NotImplementedError
|
| 498 |
+
|
| 499 |
+
def forward_features(
|
| 500 |
+
self,
|
| 501 |
+
x: torch.Tensor,
|
| 502 |
+
norm: bool = False,
|
| 503 |
+
layer_idx: int = -1,
|
| 504 |
+
strip_cls_token: bool = False
|
| 505 |
+
):
|
| 506 |
+
batch, _, h, w = x.shape
|
| 507 |
+
grid_h, grid_w = h // self.patch_size, w // self.patch_size
|
| 508 |
+
|
| 509 |
+
x = self.conv1(x)
|
| 510 |
+
x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)
|
| 511 |
+
|
| 512 |
+
if self.use_cls_token:
|
| 513 |
+
x = torch.cat(
|
| 514 |
+
[self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
|
| 515 |
+
dim=1,
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
if self.use_abs_posemb:
|
| 519 |
+
x = x + self._sample_abs_posemb(grid_h, grid_w)
|
| 520 |
+
|
| 521 |
+
if self.use_rope2d:
|
| 522 |
+
self.rope.update_grid(x.device, grid_h, grid_w)
|
| 523 |
+
|
| 524 |
+
x = self.ln_pre(x)
|
| 525 |
+
x = self.transformer(x, layer_idx=layer_idx)
|
| 526 |
+
|
| 527 |
+
if norm:
|
| 528 |
+
x = self.ln_post(x)
|
| 529 |
+
|
| 530 |
+
if strip_cls_token and self.use_cls_token:
|
| 531 |
+
x = x[:, 1:, :]
|
| 532 |
+
|
| 533 |
+
return x
|
| 534 |
+
|
| 535 |
+
def forward(self, x: torch.Tensor, **kwargs):
|
| 536 |
+
x = self.forward_features(x, norm=True, **kwargs)
|
| 537 |
+
x = self._pool(x)
|
| 538 |
+
|
| 539 |
+
if self.proj_dim is not None:
|
| 540 |
+
x = x @ self.proj
|
| 541 |
+
|
| 542 |
+
return x
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
class TextTransformer(nn.Module):
|
| 553 |
+
def __init__(
|
| 554 |
+
self,
|
| 555 |
+
context_length: int = 72,
|
| 556 |
+
vocab_size: int = 49408,
|
| 557 |
+
width: int = 512,
|
| 558 |
+
heads: int = 8,
|
| 559 |
+
layers: int = 12,
|
| 560 |
+
mlp_ratio: float = 4.0,
|
| 561 |
+
ls_init_value: float = None,
|
| 562 |
+
output_dim: int = 1280,
|
| 563 |
+
no_causal_mask: bool = False,
|
| 564 |
+
pad_id: int = 0,
|
| 565 |
+
pool_type: str = "argmax",
|
| 566 |
+
proj_bias: bool = False,
|
| 567 |
+
act_layer: Callable = nn.GELU,
|
| 568 |
+
norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
|
| 569 |
+
output_tokens: bool = False,
|
| 570 |
+
use_ln_post: bool = True,
|
| 571 |
+
):
|
| 572 |
+
super().__init__()
|
| 573 |
+
assert pool_type in ("first", "last", "argmax", "none")
|
| 574 |
+
self.pool_type = pool_type
|
| 575 |
+
self.output_tokens = output_tokens
|
| 576 |
+
self.num_pos = self.context_length = context_length
|
| 577 |
+
self.vocab_size = vocab_size
|
| 578 |
+
self.width = width
|
| 579 |
+
self.output_dim = output_dim
|
| 580 |
+
self.heads = heads
|
| 581 |
+
self.pad_id = pad_id
|
| 582 |
+
self.layers = layers
|
| 583 |
+
|
| 584 |
+
self.token_embedding = nn.Embedding(vocab_size, width)
|
| 585 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
|
| 586 |
+
|
| 587 |
+
self.transformer = Transformer(
|
| 588 |
+
width=width,
|
| 589 |
+
layers=layers,
|
| 590 |
+
heads=heads,
|
| 591 |
+
mlp_ratio=mlp_ratio,
|
| 592 |
+
ls_init_value=ls_init_value,
|
| 593 |
+
act_layer=act_layer,
|
| 594 |
+
norm_layer=norm_layer,
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()
|
| 598 |
+
|
| 599 |
+
if no_causal_mask:
|
| 600 |
+
self.attn_mask = None
|
| 601 |
+
else:
|
| 602 |
+
self.register_buffer(
|
| 603 |
+
"attn_mask", self.build_causal_mask(), persistent=False
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
if pool_type == "attn" or pool_type == "attn_eos":
|
| 607 |
+
self.attn_pool = AttentionPooling(
|
| 608 |
+
embed_dim=width,
|
| 609 |
+
num_heads=heads,
|
| 610 |
+
act_layer=act_layer,
|
| 611 |
+
norm_layer=norm_layer,
|
| 612 |
+
)
|
| 613 |
+
else: # argmax
|
| 614 |
+
self.attn_pool = None
|
| 615 |
+
|
| 616 |
+
if proj_bias:
|
| 617 |
+
self.text_projection = nn.Linear(width, output_dim)
|
| 618 |
+
else:
|
| 619 |
+
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
| 620 |
+
|
| 621 |
+
def build_causal_mask(self):
|
| 622 |
+
# lazily create causal attention mask, with full attention between the tokens
|
| 623 |
+
# pytorch uses additive attention mask; fill with -inf
|
| 624 |
+
mask = torch.empty(self.num_pos, self.num_pos)
|
| 625 |
+
mask.fill_(float("-inf"))
|
| 626 |
+
mask.triu_(1) # zero out the lower diagonal
|
| 627 |
+
return mask
|
| 628 |
+
|
| 629 |
+
def load_ckpt(self, ckpt_path: str, verbose: bool = True):
|
| 630 |
+
_sd = torch.load(ckpt_path, weights_only=True)
|
| 631 |
+
if "state_dict" in _sd:
|
| 632 |
+
_sd = _sd["state_dict"]
|
| 633 |
+
elif "weights" in _sd:
|
| 634 |
+
_sd = _sd["weights"]
|
| 635 |
+
|
| 636 |
+
_sd = {k.replace("module.", ""): v for k, v in _sd.items()}
|
| 637 |
+
|
| 638 |
+
m, u = self.load_state_dict(_sd, strict=False)
|
| 639 |
+
|
| 640 |
+
if verbose or (m or u):
|
| 641 |
+
logger.info(f"Missing keys for loading model: {m}")
|
| 642 |
+
logger.info(f"Unexpected keys for loading model: {u}")
|
| 643 |
+
print(f"Missing keys for loading model: {m}")
|
| 644 |
+
print(f"Unexpected keys for loading model: {u}")
|
| 645 |
+
|
| 646 |
+
def build_cls_mask(self, text):
|
| 647 |
+
cls_mask = (text != self.pad_id).unsqueeze(1)
|
| 648 |
+
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
|
| 649 |
+
additive_mask = torch.empty(cls_mask.shape, device=cls_mask.device)
|
| 650 |
+
additive_mask.fill_(0)
|
| 651 |
+
additive_mask.masked_fill_(~cls_mask, float("-inf"))
|
| 652 |
+
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
|
| 653 |
+
return additive_mask
|
| 654 |
+
|
| 655 |
+
def text_global_pool(
|
| 656 |
+
self, x, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
|
| 657 |
+
):
|
| 658 |
+
if pool_type == "first":
|
| 659 |
+
pooled, tokens = x[:, 0], x[:, 1:]
|
| 660 |
+
elif pool_type == "last":
|
| 661 |
+
pooled, tokens = x[:, -1], x[:, :-1]
|
| 662 |
+
elif pool_type == "argmax":
|
| 663 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
| 664 |
+
assert text is not None
|
| 665 |
+
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
|
| 666 |
+
else:
|
| 667 |
+
pooled = tokens = x
|
| 668 |
+
|
| 669 |
+
return pooled, tokens
|
| 670 |
+
|
| 671 |
+
def forward(self, text):
|
| 672 |
+
seq_len = text.shape[1]
|
| 673 |
+
x = self.token_embedding(
|
| 674 |
+
text
|
| 675 |
+
)
|
| 676 |
+
attn_mask = self.attn_mask
|
| 677 |
+
if attn_mask is not None:
|
| 678 |
+
attn_mask = attn_mask[:seq_len, :seq_len]
|
| 679 |
+
|
| 680 |
+
x = x + self.positional_embedding[:seq_len]
|
| 681 |
+
x = self.transformer(x, attn_mask=attn_mask)
|
| 682 |
+
|
| 683 |
+
x = self.ln_final(x)
|
| 684 |
+
pooled, tokens = self.text_global_pool(x, text, pool_type=self.pool_type)
|
| 685 |
+
|
| 686 |
+
if self.text_projection is not None:
|
| 687 |
+
if isinstance(self.text_projection, nn.Linear):
|
| 688 |
+
pooled = self.text_projection(pooled)
|
| 689 |
+
else:
|
| 690 |
+
pooled = pooled @ self.text_projection
|
| 691 |
+
|
| 692 |
+
if self.output_tokens:
|
| 693 |
+
return pooled, tokens
|
| 694 |
+
|
| 695 |
+
return pooled
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
class CLIP(TextTransformer):
|
| 701 |
+
def __init__(
|
| 702 |
+
self,
|
| 703 |
+
vision_cfg: PEConfig,
|
| 704 |
+
text_cfg: PETextConfig,
|
| 705 |
+
init_logit_scale: float = np.log(1 / 0.07)
|
| 706 |
+
):
|
| 707 |
+
super(CLIP, self).__init__(**asdict(text_cfg))
|
| 708 |
+
self.visual = VisionTransformer(**asdict(vision_cfg))
|
| 709 |
+
self.image_size = self.visual.image_size # For ease of use
|
| 710 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
def encode_image(self, image, normalize: bool = False):
|
| 714 |
+
x = self.visual(image)
|
| 715 |
+
return F.normalize(x, dim=-1) if normalize else x
|
| 716 |
+
|
| 717 |
+
def encode_video(self, video, normalize: bool = False): # b n c h w
|
| 718 |
+
b, n, c, h, w = video.shape
|
| 719 |
+
frms = video.reshape(b * n, c, h, w)
|
| 720 |
+
frm_feats = self.encode_image(frms, normalize=normalize)
|
| 721 |
+
video_feats = frm_feats.reshape(b, n, -1)
|
| 722 |
+
video_feats = video_feats.mean(dim=1)
|
| 723 |
+
return video_feats
|
| 724 |
+
|
| 725 |
+
def encode_text(self, text, normalize: bool = False):
|
| 726 |
+
x = super().forward(text)
|
| 727 |
+
return F.normalize(x, dim=-1) if normalize else x
|
| 728 |
+
|
| 729 |
+
def forward(
|
| 730 |
+
self,
|
| 731 |
+
image: Optional[torch.Tensor] = None,
|
| 732 |
+
text: Optional[torch.Tensor] = None,
|
| 733 |
+
):
|
| 734 |
+
image_features = (
|
| 735 |
+
self.encode_image(image, normalize=True) if image is not None else None
|
| 736 |
+
)
|
| 737 |
+
text_features = (
|
| 738 |
+
self.encode_text(text, normalize=True) if text is not None else None
|
| 739 |
+
)
|
| 740 |
+
return image_features, text_features, self.logit_scale.exp()
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
@classmethod
|
| 744 |
+
def from_config(
|
| 745 |
+
cls,
|
| 746 |
+
name: str,
|
| 747 |
+
pretrained: bool = False,
|
| 748 |
+
checkpoint_path: Optional[str] = None # To load your own
|
| 749 |
+
):
|
| 750 |
+
if name not in PE_VISION_CONFIG or name not in PE_TEXT_CONFIG:
|
| 751 |
+
raise RuntimeError(f"{name} not found in configs.")
|
| 752 |
+
|
| 753 |
+
model = cls(PE_VISION_CONFIG[name], PE_TEXT_CONFIG[name])
|
| 754 |
+
if pretrained:
|
| 755 |
+
model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))
|
| 756 |
+
|
| 757 |
+
return model
|
| 758 |
+
|
| 759 |
+
@classmethod
|
| 760 |
+
def available_configs(cls):
|
| 761 |
+
return [k for k in PE_VISION_CONFIG if k in PE_TEXT_CONFIG]
|
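`build_causal_mask` in the text encoder fills an additive attention mask with `-inf` strictly above the diagonal, so each position can attend only to itself and earlier tokens. A minimal pure-Python sketch of the same pattern (no torch; the size 4 is illustrative):

```python
def build_causal_mask(n):
    # additive attention mask: 0.0 on and below the diagonal,
    # -inf strictly above it, so position i attends only to j <= i
    neg_inf = float("-inf")
    return [[0.0 if j <= i else neg_inf for j in range(n)] for i in range(n)]

mask = build_causal_mask(4)
# row 0 can only see position 0; the last row sees every position
print(mask[0])  # [0.0, -inf, -inf, -inf]
print(mask[3])  # [0.0, 0.0, 0.0, 0.0]
```

Adding this mask to the attention logits before the softmax zeroes the probability of attending to future tokens, which is the same effect `triu_(1)` over a `-inf`-filled tensor achieves above.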
core/vision_encoder/rope.py (ADDED, +347 lines)
```python
from math import log, pi
from typing import Literal, Optional, Union

import torch
from einops import rearrange, repeat
from torch import Tensor, broadcast_tensors, einsum, nn
from torch.amp import autocast
from torch.nn import Module, ModuleList

# helper functions


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


# broadcat, as tortoise-tts was using it


def broadcat(tensors, dim=-1):
    broadcasted_tensors = broadcast_tensors(*tensors)
    return torch.cat(broadcasted_tensors, dim=dim)


# rotary embedding helper functions


def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


@autocast("cuda", enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
    dtype = t.dtype

    if t.ndim == 3:
        seq_len = t.shape[seq_dim]
        freqs = freqs[-seq_len:]

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"

    t_left, t, t_right = (
        t[..., :start_index],
        t[..., start_index:end_index],
        t[..., end_index:],
    )
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    out = torch.cat((t_left, t, t_right), dim=-1)

    return out.type(dtype)


# learned rotation helpers


def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    if exists(freq_ranges):
        rotations = einsum("..., f -> ... f", rotations, freq_ranges)
        rotations = rearrange(rotations, "... r f -> ... (r f)")

    rotations = repeat(rotations, "... n -> ... (n r)", r=2)
    return apply_rotary_emb(rotations, t, start_index=start_index)


# classes


class RotaryEmbedding(Module):
    def __init__(
        self,
        dim,
        custom_freqs: Optional[Tensor] = None,
        freqs_for: Union[
            Literal["lang"], Literal["pixel"], Literal["constant"]
        ] = "lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.0,
        theta_rescale_factor=1.0,
        seq_before_head_dim=False,
        cache_if_possible=True,
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

        theta *= theta_rescale_factor ** (dim / (dim - 2))

        self.freqs_for = freqs_for

        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (
                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
            )
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()

        self.cache_if_possible = cache_if_possible

        self.tmp_store("cached_freqs", None)
        self.tmp_store("cached_scales", None)

        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

        self.learned_freq = learned_freq

        # dummy for device

        self.tmp_store("dummy", torch.tensor(0))

        # default sequence dimension

        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        # interpolation factors

        assert interpolate_factor >= 1.0
        self.interpolate_factor = interpolate_factor

        # xpos

        self.use_xpos = use_xpos
        if not use_xpos:
            self.tmp_store("scale", None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = xpos_scale_base
        self.tmp_store("scale", scale)

        # add apply_rotary_emb as static method

        self.apply_rotary_emb = staticmethod(apply_rotary_emb)

    @property
    def device(self):
        return self.dummy.device

    def tmp_store(self, key, value):
        self.register_buffer(key, value, persistent=False)

    def get_seq_pos(self, seq_len, device, dtype, offset=0):
        return (
            torch.arange(seq_len, device=device, dtype=dtype) + offset
        ) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert (
            not self.use_xpos
        ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"

        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]

        freqs = self.forward(
            self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset),
            seq_len=seq_len,
            offset=offset,
        )

        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")

        return apply_rotary_emb(freqs, t, seq_dim=seq_dim)

    def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
        seq_dim = default(seq_dim, self.default_seq_dim)

        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
        assert q_len <= k_len

        rotated_q = self.rotate_queries_or_keys(
            q, seq_dim=seq_dim, offset=k_len - q_len + offset
        )
        rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def rotate_queries_and_keys(self, q, k, seq_dim=None):
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)

        freqs = self.forward(seq, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len).to(dtype)

        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")
            scale = rearrange(scale, "n d -> n 1 d")

        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
        assert self.use_xpos

        should_cache = self.cache_if_possible and exists(seq_len)

        if (
            should_cache
            and exists(self.cached_scales)
            and (seq_len + offset) <= self.cached_scales.shape[0]
        ):
            return self.cached_scales[offset : (offset + seq_len)]

        scale = 1.0
        if self.use_xpos:
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, "n -> n 1")
            scale = torch.cat((scale, scale), dim=-1)

        if should_cache:
            self.tmp_store("cached_scales", scale)

        return scale

    def get_axial_freqs(self, *dims):
        Colon = slice(None)
        all_freqs = []

        for ind, dim in enumerate(dims):
            if self.freqs_for == "pixel":
                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
            else:
                pos = torch.arange(dim, device=self.device)

            freqs = self.forward(pos, seq_len=dim)

            all_axis = [None] * len(dims)
            all_axis[ind] = Colon

            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])

        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim=-1)

    @autocast("cuda", enabled=False)
    def forward(self, t: Tensor, seq_len=None, offset=0):
        should_cache = (
            self.cache_if_possible
            and not self.learned_freq
            and exists(seq_len)
            and self.freqs_for != "pixel"
        )

        if (
            should_cache
            and exists(self.cached_freqs)
            and (offset + seq_len) <= self.cached_freqs.shape[0]
        ):
            return self.cached_freqs[offset : (offset + seq_len)].detach()

        freqs = self.freqs

        freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        if should_cache:
            self.tmp_store("cached_freqs", freqs.detach())

        return freqs


class Rope2D:
    """Helper class to apply RoPE2D as well as interpolate on the fly."""

    def __init__(self, dim, use_cls_token=False):
        self.dim = dim
        self.use_cls_token = use_cls_token
        self.grid_size = None
        self.freq = None

    def init_tensors(self):
        self.rope = RotaryEmbedding(self.dim // 2)

    def update_grid(self, device, grid_h, grid_w):
        if self.grid_size != (grid_h, grid_w):
            self.grid_size = (grid_h, grid_w)

            self.rope = self.rope.to(device)

            if self.use_cls_token:
                # +1 to leave space for the cls token to be (0, 0)
                grid_y_range = torch.arange(grid_h, device=device) + 1
                grid_x_range = torch.arange(grid_w, device=device) + 1
            else:
                grid_y_range = torch.arange(grid_h, device=device)
                grid_x_range = torch.arange(grid_w, device=device)

            freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
            freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
            freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)

            if self.use_cls_token:
                freq = torch.cat(
                    [torch.zeros(1, freq.shape[-1], device=device), freq], dim=0
                )

            self.freq = freq[None, ...]

        self.freq = self.freq.to(device)

    def __call__(self, q, k):
        # batch, heads, seq, dim = q.shape
        q = apply_rotary_emb(self.freq[:, None, :, :], q)
        k = apply_rotary_emb(self.freq[:, None, :, :], k)

        return q, k
```
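`apply_rotary_emb` rotates each (even, odd) feature pair by a position-dependent angle, with per-pair frequencies `1 / theta^(2i/dim)` in the `freqs_for="lang"` case. A self-contained sketch of that rotation on plain floats (the vector, `pos`, and `theta` are illustrative); since it is a pure rotation, the vector's norm is preserved:

```python
import math

def rope_freqs(dim, theta=10000.0):
    # one frequency per feature pair: 1 / theta^(2i/dim)
    return [1.0 / theta ** (2 * i / dim) for i in range(dim // 2)]

def rotate(vec, pos, freqs):
    # rotate each (x1, x2) pair by angle pos * freq, matching
    # t * cos + rotate_half(t) * sin on interleaved pairs
    out = []
    for i, f in enumerate(freqs):
        x1, x2 = vec[2 * i], vec[2 * i + 1]
        c, s = math.cos(pos * f), math.sin(pos * f)
        out += [x1 * c - x2 * s, x1 * s + x2 * c]
    return out

v = [1.0, 0.0, 0.5, -0.5]
freqs = rope_freqs(len(v))
r = rotate(v, pos=3, freqs=freqs)

norm = lambda u: math.sqrt(sum(x * x for x in u))
assert abs(norm(v) - norm(r)) < 1e-9  # rotation never changes the norm
```

`Rope2D` above applies exactly this idea along both grid axes, concatenating x- and y-axis frequencies so each patch's rotation encodes its 2D position.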
core/vision_encoder/tokenizer.py (ADDED, +342 lines)
```python
""" CLIP tokenizer

Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""

import gzip
import html
import os
import random
import string
from functools import lru_cache, partial
from typing import Callable, List, Optional, Union

import ftfy
import regex as re
import torch

# https://stackoverflow.com/q/62691279
os.environ["TOKENIZERS_PARALLELISM"] = "false"

DEFAULT_CONTEXT_LENGTH = 77  # default context length for OpenAI CLIP


@lru_cache()
def default_bpe():
    return os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
    )


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def _clean_canonicalize(x):
    # basic, remove whitespace, remove punctuation, lower case
    return canonicalize_text(basic_clean(x))


def _clean_lower(x):
    # basic, remove whitespace, lower case
    return whitespace_clean(basic_clean(x)).lower()


def _clean_whitespace(x):
    # basic, remove whitespace
    return whitespace_clean(basic_clean(x))


def get_clean_fn(type: str):
    if type == "canonicalize":
        return _clean_canonicalize
    elif type == "lower":
        return _clean_lower
    elif type == "whitespace":
        return _clean_whitespace
    else:
        assert False, f"Invalid clean function ({type})."


def canonicalize_text(text, *, keep_punctuation_exact_string=None):
    """Returns canonicalized `text` (lowercase and punctuation removed).

    From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94

    Args:
      text: string to be canonicalized.
      keep_punctuation_exact_string: If provided, then this exact string is kept.
        For example providing '{}' will keep any occurrences of '{}' (but will
        still remove '{' and '}' that appear separately).
    """
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans("", "", string.punctuation))
            for part in text.split(keep_punctuation_exact_string)
        )
    else:
        text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    text = re.sub(r"\s+", " ", text)
    return text.strip()


class SimpleTokenizer(object):
    def __init__(
        self,
        bpe_path: str = default_bpe(),
        additional_special_tokens: Optional[List[str]] = None,
        context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
        clean: str = "lower",
        reduction_mask: str = "",
    ):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        special_tokens = ["<start_of_text>", "<end_of_text>"]
        if additional_special_tokens:
            special_tokens += additional_special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {t: t for t in special_tokens}
        special = "|".join(special_tokens)
        self.pat = re.compile(
            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]
        self.sot_token_id = self.all_special_ids[0]
        self.eot_token_id = self.all_special_ids[1]
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        self.reduction_fn = (
            get_reduction_mask_fn(reduction_mask) if reduction_mask else None
        )

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:  # `first` no longer occurs; keep the tail as-is
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = self.clean_fn(text)
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text

    def __call__(
        self, texts: Union[str, List[str]], context_length: Optional[int] = None
    ) -> torch.LongTensor:
        """Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
        """
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        assert context_length, "Please set a valid context length"

        if self.reduction_fn is not None:
            # use reduction strategy for tokenize if set, otherwise default to truncation below
            return self.reduction_fn(
                texts,
                context_length=context_length,
                sot_token_id=self.sot_token_id,
                eot_token_id=self.eot_token_id,
                encode_fn=self.encode,
            )

        all_tokens = [
            [self.sot_token_id] + self.encode(text) + [self.eot_token_id]
            for text in texts
        ]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # Truncate
                tokens[-1] = self.eot_token_id
            result[i, : len(tokens)] = torch.tensor(tokens)

        return result


def random_mask_tokenize(
    texts: Union[str, List[str]],
    context_length: int,
    sot_token_id: int,
    eot_token_id: int,
    encode_fn: Callable,
    shuffle: bool = False,
):
    all_tokens = [encode_fn(text) for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        tokens = torch.tensor(tokens)
        num_tokens = len(tokens)
        if num_tokens > context_length - 2:  # 2 for sot and eot token
            num_keep = context_length - 2
            indices = torch.randperm(len(tokens))
            indices = indices[:num_keep]
            if not shuffle:
                indices = indices.msort()
            tokens = tokens[indices]
            num_tokens = num_keep
        result[i, 0] = sot_token_id
        result[i, 1 : num_tokens + 1] = tokens
        result[i, num_tokens + 1] = eot_token_id

    return result


def simple_mask_tokenize(
    texts: Union[str, List[str]],
    context_length: int,
    sot_token_id: int,
    eot_token_id: int,
    encode_fn: Callable,
):
    all_tokens = [encode_fn(text) for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        num_tokens = len(tokens)
        if num_tokens > context_length - 2:  # 2 for sot and eot token
            num_keep = context_length - 2
            start_index = random.randint(0, num_tokens - num_keep)  # high is incl
            tokens = tokens[start_index : start_index + num_keep]
        tokens = [sot_token_id] + tokens + [eot_token_id]
        result[i, : len(tokens)] = torch.tensor(tokens)

    return result


def get_reduction_mask_fn(type: str):
    """Choose strategy for dropping (masking) tokens to achieve target context length"""
    assert type in ("simple", "random", "shuffle")
    if type == "simple":
        return simple_mask_tokenize  # randomly select block [start:end]
    elif type == "random":
        return random_mask_tokenize  # randomly drop tokens (keep order)
    elif type == "shuffle":
        return partial(
            random_mask_tokenize, shuffle=True
        )  # randomly drop tokens (shuffle order)
```
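The `bytes_to_unicode` table above is what makes the BPE fully reversible: every one of the 256 byte values is mapped to a distinct printable unicode character, so no input text can ever produce an unknown token. As an illustration (the function is duplicated here so the sketch is self-contained, independent of the module above), the construction yields a bijection over all bytes:

```python
def bytes_to_unicode():
    # Printable byte values keep their own character; the remaining bytes
    # are shifted into the 256+ range so every byte maps to a visible,
    # distinct character instead of whitespace/control characters.
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, [chr(n) for n in cs]))


table = bytes_to_unicode()
# Covers all 256 bytes, with no two bytes sharing a character,
# so decode() can invert the mapping exactly.
assert len(table) == 256
assert len(set(table.values())) == 256
```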
core/vision_encoder/transforms.py (added, 31 lines)
```python
import torchvision.transforms as T

from core.vision_encoder.tokenizer import SimpleTokenizer


def get_image_transform(
    image_size: int,
    center_crop: bool = False,
    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,  # We used bilinear during training
):
    if center_crop:
        crop = [
            T.Resize(image_size, interpolation=interpolation),
            T.CenterCrop(image_size),
        ]
    else:
        # "Squash": most versatile
        crop = [T.Resize((image_size, image_size), interpolation=interpolation)]

    return T.Compose(crop + [
        T.Lambda(lambda x: x.convert("RGB")),
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
    ])


def get_text_tokenizer(context_length: int):
    return SimpleTokenizer(context_length=context_length)
```
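The mean/std of 0.5 used by the PE transform maps `ToTensor`'s `[0, 1]` pixel range onto `[-1, 1]`, rather than the ImageNet statistics the other baselines use. A one-line sketch of the per-channel arithmetic `T.Normalize` applies:

```python
def normalize(pixel: float, mean: float = 0.5, std: float = 0.5) -> float:
    # Same per-channel arithmetic T.Normalize applies after ToTensor
    # has scaled pixel values into [0, 1].
    return (pixel - mean) / std


assert normalize(0.0) == -1.0  # black maps to -1
assert normalize(1.0) == 1.0   # white maps to +1
assert normalize(0.5) == 0.0   # mid-gray maps to 0
```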
models.py (added, 331 lines)
```python
"""Minimal model-loading code for the 7 VFM baselines in the paper."""

from __future__ import annotations

import sys
from pathlib import Path
from typing import Callable

import torch
import torch.nn as nn
from torchvision import transforms
from transformers import AutoConfig, AutoImageProcessor, AutoModel

ROOT = Path(__file__).resolve().parent
WEIGHTS_DIR = ROOT / "weights"

MODEL_SPECS = {
    "metacliplin": {
        "paper_name": "MetaCLIP-Linear",
        "checkpoint": "metacliplin0.pth",
        "hf_model": "facebook/metaclip-h14-fullcc2.5b",
        "feature_dim": 1280,
        "image_size": 224,
        "pooler_output": True,
    },
    "metaclip2lin": {
        "paper_name": "MetaCLIP2-Linear",
        "checkpoint": "metaclip2lin0.pth",
        "hf_model": "facebook/metaclip-2-worldwide-giant",
        "feature_dim": 1280,
        "image_size": 224,
        "pooler_output": True,
    },
    "sigliplin": {
        "paper_name": "SigLIP-Linear",
        "checkpoint": "sigliplin0.pth",
        "hf_model": "google/siglip-large-patch16-384",
        "feature_dim": 1024,
        "image_size": 384,
        "pooler_output": True,
    },
    "siglip2lin": {
        "paper_name": "SigLIP2-Linear",
        "checkpoint": "siglip2lin0.pth",
        "hf_model": "google/siglip2-giant-opt-patch16-384",
        "feature_dim": 1536,
        "image_size": 384,
        "pooler_output": True,
    },
    "pelin": {
        "paper_name": "PE-CLIP-Linear",
        "checkpoint": "pelin0.pth",
        "feature_dim": 1024,
        "image_size": 336,
        "pooler_output": False,
    },
    "dinov2lin": {
        "paper_name": "DINOv2-Linear",
        "checkpoint": "dinov2lin0.pth",
        "feature_dim": 1024,
        "pooler_output": False,
    },
    "dinov3lin": {
        "paper_name": "DINOv3-Linear",
        "checkpoint": "dinov3lin0.pth",
        "hf_model": "facebook/dinov3-vit7b16-pretrain-lvd1689m",
        "feature_dim": 4096,
        "pooler_output": False,
    },
}

ALIASES = {
    "MetaCLIP-Linear": "metacliplin",
    "MetaCLIP2-Linear": "metaclip2lin",
    "SigLIP-Linear": "sigliplin",
    "SigLIP2-Linear": "siglip2lin",
    "PE-CLIP-Linear": "pelin",
    "DINOv2-Linear": "dinov2lin",
    "DINOv3-Linear": "dinov3lin",
}


def canonical_model_name(name: str) -> str:
    if name in MODEL_SPECS:
        return name
    if name in ALIASES:
        return ALIASES[name]
    raise KeyError(f"Unknown model: {name}")


def default_checkpoint_path(model_name: str) -> Path:
    model_name = canonical_model_name(model_name)
    return WEIGHTS_DIR / MODEL_SPECS[model_name]["checkpoint"]


def _resolve_device(device: str | torch.device | None = None) -> torch.device:
    if device is None:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(device)


def _load_checkpoint(checkpoint_path: str | Path) -> dict:
    checkpoint = torch.load(str(checkpoint_path), map_location="cpu", weights_only=False)
    if isinstance(checkpoint, dict):
        for key in ("state_dict", "model", "model_state_dict"):
            if key in checkpoint and isinstance(checkpoint[key], dict):
                checkpoint = checkpoint[key]
                break

    normalized = {}
    for key, value in checkpoint.items():
        normalized[key[7:] if key.startswith("module.") else key] = value
    return normalized


def _infer_feature_dim(state_dict: dict, default_dim: int) -> int:
    head_weight = state_dict.get("head.weight")
    if isinstance(head_weight, torch.Tensor) and head_weight.ndim == 2:
        return int(head_weight.shape[1])
    return default_dim


def _load_image_processor(model_name: str):
    try:
        return AutoImageProcessor.from_pretrained(model_name, local_files_only=True)
    except Exception:
        try:
            return AutoImageProcessor.from_pretrained(model_name)
        except Exception:
            return None


def _load_backbone(model_name: str):
    try:
        return AutoModel.from_pretrained(model_name, local_files_only=True)
    except Exception:
        config = AutoConfig.from_pretrained(model_name)
        return AutoModel.from_config(config)


class _PoolerLinearModel(nn.Module):
    def __init__(self, backbone: nn.Module, feature_dim: int):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(feature_dim, 2)

    def forward(self, x):
        with torch.no_grad():
            outputs = self.backbone(x)
            features = outputs.pooler_output.float()
        return self.head(features)


class _ClsTokenLinearModel(nn.Module):
    def __init__(self, backbone: nn.Module, feature_dim: int):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(feature_dim, 2)

    def forward(self, x):
        with torch.no_grad():
            outputs = self.backbone(x)
            features = outputs.last_hidden_state[:, 0].float()
        return self.head(features)


class _PELinearModel(nn.Module):
    def __init__(self, backbone: nn.Module, feature_dim: int):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(feature_dim, 2)

    def forward(self, x):
        with torch.no_grad():
            features = self.backbone(x)
            if isinstance(features, torch.Tensor):
                features = features.float()
        return self.head(features)


def _finalize_model(model: nn.Module, state_dict: dict, device=None) -> nn.Module:
    model.load_state_dict(state_dict, strict=False)
    model.to(_resolve_device(device))
    model.eval()
    return model


def _build_clip_transform(image_size: int, image_processor=None):
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    if image_processor is not None:
        mean = getattr(image_processor, "image_mean", mean)
        std = getattr(image_processor, "image_std", std)
    return transforms.Compose(
        [
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )


def _build_dino_transform():
    return transforms.Compose(
        [
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )


def load_metacliplin(checkpoint_path: str | Path | None = None, device=None):
    spec = MODEL_SPECS["metacliplin"]
    checkpoint_path = checkpoint_path or default_checkpoint_path("metacliplin")
    state_dict = _load_checkpoint(checkpoint_path)
    feature_dim = _infer_feature_dim(state_dict, spec["feature_dim"])
    image_processor = _load_image_processor(spec["hf_model"])
    backbone = _load_backbone(spec["hf_model"])
    model = _PoolerLinearModel(backbone.vision_model, feature_dim)
    model = _finalize_model(model, state_dict, device=device)
    return model, _build_clip_transform(spec["image_size"], image_processor)


def load_metaclip2lin(checkpoint_path: str | Path | None = None, device=None):
    spec = MODEL_SPECS["metaclip2lin"]
    checkpoint_path = checkpoint_path or default_checkpoint_path("metaclip2lin")
    state_dict = _load_checkpoint(checkpoint_path)
    feature_dim = _infer_feature_dim(state_dict, spec["feature_dim"])
    image_processor = _load_image_processor(spec["hf_model"])
    backbone = _load_backbone(spec["hf_model"])
    model = _PoolerLinearModel(backbone.vision_model, feature_dim)
    model = _finalize_model(model, state_dict, device=device)
    return model, _build_clip_transform(spec["image_size"], image_processor)


def load_sigliplin(checkpoint_path: str | Path | None = None, device=None):
    spec = MODEL_SPECS["sigliplin"]
    checkpoint_path = checkpoint_path or default_checkpoint_path("sigliplin")
    state_dict = _load_checkpoint(checkpoint_path)
    feature_dim = _infer_feature_dim(state_dict, spec["feature_dim"])
    image_processor = _load_image_processor(spec["hf_model"])
    backbone = _load_backbone(spec["hf_model"])
    model = _PoolerLinearModel(backbone.vision_model, feature_dim)
    model = _finalize_model(model, state_dict, device=device)
    return model, _build_clip_transform(spec["image_size"], image_processor)


def load_siglip2lin(checkpoint_path: str | Path | None = None, device=None):
    spec = MODEL_SPECS["siglip2lin"]
    checkpoint_path = checkpoint_path or default_checkpoint_path("siglip2lin")
    state_dict = _load_checkpoint(checkpoint_path)
    feature_dim = _infer_feature_dim(state_dict, spec["feature_dim"])
    image_processor = _load_image_processor(spec["hf_model"])
    backbone = _load_backbone(spec["hf_model"])
    model = _PoolerLinearModel(backbone.vision_model, feature_dim)
    model = _finalize_model(model, state_dict, device=device)
    return model, _build_clip_transform(spec["image_size"], image_processor)


def load_dinov2lin(checkpoint_path: str | Path | None = None, device=None):
    checkpoint_path = checkpoint_path or default_checkpoint_path("dinov2lin")
    state_dict = _load_checkpoint(checkpoint_path)
    feature_dim = _infer_feature_dim(state_dict, MODEL_SPECS["dinov2lin"]["feature_dim"])
    if feature_dim == 1536:
        candidates = ["facebook/dinov2-giant", "facebook/dinov2-large"]
    elif feature_dim == 1024:
        candidates = ["facebook/dinov2-large", "facebook/dinov2-base"]
    elif feature_dim == 768:
        candidates = ["facebook/dinov2-base", "facebook/dinov2-small"]
    else:
        candidates = ["facebook/dinov2-large"]

    last_error = None
    backbone = None
    for candidate in candidates:
        try:
            backbone = _load_backbone(candidate)
            break
        except Exception as exc:
            last_error = exc
    if backbone is None:
        raise RuntimeError(f"Failed to load DINOv2 backbone: {last_error}")

    model = _ClsTokenLinearModel(backbone, feature_dim)
    model = _finalize_model(model, state_dict, device=device)
    return model, _build_dino_transform()


def load_dinov3lin(checkpoint_path: str | Path | None = None, device=None):
    checkpoint_path = checkpoint_path or default_checkpoint_path("dinov3lin")
    state_dict = _load_checkpoint(checkpoint_path)
    feature_dim = _infer_feature_dim(state_dict, MODEL_SPECS["dinov3lin"]["feature_dim"])
    backbone = _load_backbone(MODEL_SPECS["dinov3lin"]["hf_model"])
    model = _ClsTokenLinearModel(backbone, feature_dim)
    model = _finalize_model(model, state_dict, device=device)
    return model, _build_dino_transform()


def load_pelin(checkpoint_path: str | Path | None = None, device=None):
    checkpoint_path = checkpoint_path or default_checkpoint_path("pelin")
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

    import core.vision_encoder.pe as pe
    import core.vision_encoder.transforms as pe_transforms

    state_dict = _load_checkpoint(checkpoint_path)
    feature_dim = _infer_feature_dim(state_dict, MODEL_SPECS["pelin"]["feature_dim"])
    clip_model = pe.CLIP.from_config("PE-Core-L14-336", pretrained=False)
    model = _PELinearModel(clip_model.visual, feature_dim)
    model = _finalize_model(model, state_dict, device=device)
    return model, pe_transforms.get_image_transform(MODEL_SPECS["pelin"]["image_size"])


LOADERS: dict[str, Callable] = {
    "metacliplin": load_metacliplin,
    "metaclip2lin": load_metaclip2lin,
    "sigliplin": load_sigliplin,
    "siglip2lin": load_siglip2lin,
    "pelin": load_pelin,
    "dinov2lin": load_dinov2lin,
    "dinov3lin": load_dinov3lin,
}


def load_model(model_name: str, checkpoint_path: str | Path | None = None, device=None):
    model_name = canonical_model_name(model_name)
    return LOADERS[model_name](checkpoint_path=checkpoint_path, device=device)
```
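`load_model` accepts either the short internal key or the paper-facing name from the README, resolving both through `canonical_model_name`. A trimmed, self-contained sketch of that lookup (the dicts are abbreviated here for illustration):

```python
MODEL_SPECS = {"metacliplin": {}, "dinov2lin": {}}  # abbreviated
ALIASES = {"MetaCLIP-Linear": "metacliplin", "DINOv2-Linear": "dinov2lin"}


def canonical_model_name(name: str) -> str:
    # Short keys pass through; paper names resolve via the alias table;
    # anything else is rejected up front.
    if name in MODEL_SPECS:
        return name
    if name in ALIASES:
        return ALIASES[name]
    raise KeyError(f"Unknown model: {name}")


assert canonical_model_name("DINOv2-Linear") == "dinov2lin"
assert canonical_model_name("dinov2lin") == "dinov2lin"
```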
test_vfm_baselines.py (added, 153 lines)
#!/usr/bin/env python3
"""Unified evaluation script for the 7 VFM baselines."""

from __future__ import annotations

import argparse
import json
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score
from torch.utils.data import DataLoader, Dataset

from models import LOADERS, MODEL_SPECS, canonical_model_name, default_checkpoint_path, load_model

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".JPG", ".JPEG", ".PNG")


class BinaryFolderDataset(Dataset):
    def __init__(self, real_dir: str, fake_dir: str, transform, max_samples: int | None = None):
        self.transform = transform
        real_paths = self._get_image_files(real_dir)
        fake_paths = self._get_image_files(fake_dir)
        if max_samples is not None:
            real_paths = real_paths[:max_samples]
            fake_paths = fake_paths[:max_samples]
        self.image_paths = real_paths + fake_paths
        self.labels = [0] * len(real_paths) + [1] * len(fake_paths)

    @staticmethod
    def _get_image_files(folder: str):
        folder = Path(folder)
        images = []
        for extension in IMAGE_EXTENSIONS:
            images.extend(folder.rglob(f"*{extension}"))
        return sorted(images)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        image = Image.open(image_path).convert("RGB")
        return self.transform(image), self.labels[index], str(image_path)


def evaluate(model, transform, real_dir: str, fake_dir: str, batch_size: int, num_workers: int, max_samples: int | None):
    dataset = BinaryFolderDataset(real_dir, fake_dir, transform, max_samples=max_samples)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    device = next(model.parameters()).device
    y_true = []
    y_prob = []
    y_pred = []
    paths = []

    with torch.no_grad():
        for images, labels, batch_paths in dataloader:
            images = images.to(device)
            logits = model(images)
            probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
            preds = (probs > 0.5).astype(int)

            y_true.extend(labels.numpy().tolist())
            y_prob.extend(probs.tolist())
            y_pred.extend(preds.tolist())
            paths.extend(batch_paths)

    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    y_pred = np.asarray(y_pred)

    metrics = {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "real_accuracy": float(accuracy_score(y_true[y_true == 0], y_pred[y_true == 0])),
        "fake_accuracy": float(accuracy_score(y_true[y_true == 1], y_pred[y_true == 1])),
    }
    if len(np.unique(y_true)) > 1:
        metrics["auc"] = float(roc_auc_score(y_true, y_prob))
        metrics["ap"] = float(average_precision_score(y_true, y_prob))

    samples = [
        {
            "path": path,
            "label": int(label),
            "prob_fake": float(prob),
            "pred": int(pred),
        }
        for path, label, prob, pred in zip(paths, y_true, y_prob, y_pred)
    ]
    return {"metrics": metrics, "samples": samples}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="all", help="One of: all, metacliplin, metaclip2lin, sigliplin, siglip2lin, pelin, dinov2lin, dinov3lin")
    parser.add_argument("--real-dir", required=True)
    parser.add_argument("--fake-dir", required=True)
    parser.add_argument("--checkpoint", default=None, help="Optional explicit checkpoint path for single-model evaluation")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--device", default=None)
    parser.add_argument("--save-json", default=None)
    args = parser.parse_args()

    model_names = list(LOADERS.keys()) if args.model == "all" else [canonical_model_name(args.model)]
    results = {}

    for model_name in model_names:
        checkpoint = args.checkpoint if args.model != "all" and args.checkpoint else default_checkpoint_path(model_name)
        checkpoint = Path(checkpoint)
        try:
            checkpoint_for_output = str(checkpoint.relative_to(Path(__file__).resolve().parent))
        except ValueError:
            checkpoint_for_output = str(checkpoint)
        model, transform = load_model(model_name, checkpoint_path=checkpoint, device=args.device)
        result = evaluate(
            model=model,
            transform=transform,
            real_dir=args.real_dir,
            fake_dir=args.fake_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            max_samples=args.max_samples,
        )
        results[model_name] = {
            "paper_name": MODEL_SPECS[model_name]["paper_name"],
            "checkpoint": checkpoint_for_output,
            **result,
        }

        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    output = json.dumps(results, indent=2, ensure_ascii=False)
    print(output)

    if args.save_json:
        Path(args.save_json).write_text(output + "\n", encoding="utf-8")


if __name__ == "__main__":
    main()
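The metrics block in `evaluate` computes overall, per-class (real/fake) accuracy, AUC, and average precision from the fake-class probabilities. As a quick sanity check, here is a minimal sketch of that computation on synthetic labels and probabilities (not real model outputs); it mirrors the script's use of `sklearn.metrics` and the 0.5 decision threshold:

```python
import numpy as np
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score

# Synthetic example: 0 = real, 1 = fake; probs are P(fake) per image.
y_true = np.array([0, 0, 0, 1, 1, 1])
y_prob = np.array([0.1, 0.4, 0.6, 0.7, 0.9, 0.3])
y_pred = (y_prob > 0.5).astype(int)  # same 0.5 threshold as the script

metrics = {
    "accuracy": float(accuracy_score(y_true, y_pred)),
    # Per-class accuracy: restrict both arrays to one ground-truth class.
    "real_accuracy": float(accuracy_score(y_true[y_true == 0], y_pred[y_true == 0])),
    "fake_accuracy": float(accuracy_score(y_true[y_true == 1], y_pred[y_true == 1])),
    # Threshold-free metrics use the raw probabilities, not the 0/1 predictions.
    "auc": float(roc_auc_score(y_true, y_prob)),
    "ap": float(average_precision_score(y_true, y_prob)),
}
print(metrics)
```

Note that `auc`/`ap` are only emitted by the script when both classes are present (`len(np.unique(y_true)) > 1`), since `roc_auc_score` raises on single-class input.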
weights/dinov2lin0.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3c8604c137ad296d9f6bbd239d03e792cca36b2503eb03cebc5ccb5abf740ebe
size 4546228799

weights/dinov3lin0.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:58e35c23fc4e6a279dadedac8191b4a409a760fbdc43837af9f8541a6f7b2fb9
size 26864441175

weights/metaclip2lin0.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e609f38a74ad280abe48dd0dc0111ef113f5f1cd8c4a3337a8346a22afbc5258
size 3685870062

weights/metacliplin0.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2deed4e9d96cb5f27df0579c57028e9b855162f3daf7326ec402b580e244194c
size 1261744353

weights/pelin0.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d871b813db9d1cf335ba0cde1701c72baf7bc80369238d881c55a676a59b24ff
size 1268731407

weights/siglip2lin0.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6971aab31d8ff4c061f4ae330b49f38c3f45a96bfb99e756f65af72e8f3f3b7
size 2327586086

weights/sigliplin0.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eaf932895087718f284d880fccbc565cfafbed7b2a6c12b67a346e6c878c8ab3
size 632730704
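The `weights/*.pth` entries above are git-LFS pointer files: each records the checkpoint's SHA-256 (`oid`) and byte size rather than the weights themselves, so the real files must be fetched with `git lfs pull`. After downloading, those two fields are enough to verify integrity. The helper below is an illustrative sketch, not part of the release; `verify_checkpoint` is a hypothetical name:

```python
import hashlib
from pathlib import Path


def verify_checkpoint(path: str, expected_sha256: str, expected_size: int,
                      chunk_bytes: int = 1 << 20) -> bool:
    """Check a downloaded weight file against its git-lfs pointer (oid + size)."""
    p = Path(path)
    # Cheap size check first; skips hashing multi-GB files that are truncated.
    if p.stat().st_size != expected_size:
        return False
    digest = hashlib.sha256()
    with p.open("rb") as fh:
        # Stream in chunks so large checkpoints are not loaded into memory at once.
        for chunk in iter(lambda: fh.read(chunk_bytes), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_sha256
```

For example, `verify_checkpoint("weights/sigliplin0.pth", "eaf93289...", 632730704)` should return `True` for an intact download.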