Lunahera committed on
Commit ce8f665 · verified · 1 Parent(s): 20476f4

Initial upload of simplicityprevails from local project

README.md ADDED
@@ -0,0 +1,83 @@
+ # VFM Baselines Release
+
+ This directory contains the 7 vision foundation model (VFM) baselines used in the paper:
+
+ - `MetaCLIP-Linear`
+ - `MetaCLIP2-Linear`
+ - `SigLIP-Linear`
+ - `SigLIP2-Linear`
+ - `PE-CLIP-Linear`
+ - `DINOv2-Linear`
+ - `DINOv3-Linear`
+
+ ## Contents
+
+ - `models.py`: unified model-loading code for all 7 baselines
+ - `test_vfm_baselines.py`: unified evaluation script
+ - `weights/`: released checkpoints
+ - `core/vision_encoder/`: vendored PE vision encoder code required by `PE-CLIP-Linear`
+
+ ## Model Names
+
+ The unified loader and test script accept these names:
+
+ - `metacliplin`
+ - `metaclip2lin`
+ - `sigliplin`
+ - `siglip2lin`
+ - `pelin`
+ - `dinov2lin`
+ - `dinov3lin`
+
+ The paper names, such as `MetaCLIP-Linear` and `DINOv3-Linear`, are also accepted.
+
+ ## Usage
+
+ Evaluate a single model:
+
+ ```bash
+ python test_vfm_baselines.py \
+     --model sigliplin \
+     --real-dir /path/to/0_real \
+     --fake-dir /path/to/1_fake \
+     --max-samples 100
+ ```
+
+ Evaluate all 7 models:
+
+ ```bash
+ python test_vfm_baselines.py \
+     --model all \
+     --real-dir /path/to/0_real \
+     --fake-dir /path/to/1_fake \
+     --max-samples 100
+ ```
+
+ Optional arguments:
+
+ - `--checkpoint`: override the default checkpoint for single-model evaluation
+ - `--batch-size`: batch size for evaluation
+ - `--num-workers`: number of dataloader workers
+ - `--device`: explicit device such as `cuda:0` or `cpu`
+ - `--save-json`: save results to a JSON file
+
+ ## Dependencies
+
+ The release code expects these Python packages:
+
+ - `torch`
+ - `torchvision`
+ - `transformers`
+ - `scikit-learn`
+ - `Pillow`
+ - `timm`
+ - `einops`
+ - `ftfy`
+ - `regex`
+ - `huggingface_hub`
+
+ ## Notes
+
+ - The CLIP-family and DINO-family baselines instantiate the backbone from Hugging Face model configs and then load the released checkpoint.
+ - `PE-CLIP-Linear` uses the vendored `core/vision_encoder` code in this directory.
+ - The checkpoints in `weights/` are arranged locally for packaging convenience. For public release, they can be uploaded under the same filenames.
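For context on how real-vs-fake detectors like these are typically scored, here is a minimal, self-contained sketch of computing accuracy and average precision with `scikit-learn` (one of the listed dependencies). This is not the actual `test_vfm_baselines.py` API; the arrays are made-up stand-ins for per-image scores.

```python
import numpy as np
from sklearn.metrics import accuracy_score, average_precision_score

# Made-up stand-ins: label 0 for images from 0_real, label 1 for 1_fake,
# and a per-image "fake" probability from some detector.
labels = np.array([0, 0, 0, 1, 1, 1])
scores = np.array([0.10, 0.40, 0.35, 0.80, 0.70, 0.90])

acc = accuracy_score(labels, scores >= 0.5)   # thresholded at 0.5
ap = average_precision_score(labels, scores)  # threshold-free
```

On this toy data every fake scores above every real image, so both metrics come out perfect; real evaluations report these numbers per dataset.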
core/__init__.py ADDED
File without changes
core/vision_encoder/__init__.py ADDED
File without changes
core/vision_encoder/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+ size 1356917
core/vision_encoder/config.py ADDED
@@ -0,0 +1,261 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ """
+ Include all available vision encoder configurations.
+ """
+
+ from dataclasses import dataclass, replace
+ from functools import partial
+ from typing import Callable, Optional, Sequence, Tuple, List
+
+ from huggingface_hub import hf_hub_download
+
+
+ def fetch_pe_checkpoint(name: str, path: Optional[str] = None):
+     path = path or f"hf://facebook/{name}:{name}.pt"
+
+     if path.startswith("hf://"):
+         # Load from huggingface
+         path = path[len("hf://"):]
+         repo, file = path.split(":")
+         return hf_hub_download(repo_id=repo, filename=file)
+     else:
+         return path
+
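The `hf://{repo}:{file}` locator convention handled by `fetch_pe_checkpoint` can be sketched in isolation, without the network call. `parse_hf_path` below is a hypothetical helper for illustration, not part of this release:

```python
def parse_hf_path(path: str):
    """Split an "hf://{repo}:{file}" locator into (repo_id, filename),
    mirroring the string handling in fetch_pe_checkpoint."""
    assert path.startswith("hf://")
    repo, file = path[len("hf://"):].split(":")
    return repo, file

name = "PE-Core-L14-336"  # any config name; "facebook/" repo prefix as in the default above
repo, file = parse_hf_path(f"hf://facebook/{name}:{name}.pt")
```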
+
+ @dataclass
+ class PEConfig:
+     """ Vision Tower Config. """
+     patch_size: int
+     width: int
+     layers: int
+     heads: int
+     mlp_ratio: float
+     output_dim: Optional[int]
+
+     ls_init_value: float = None
+     drop_path: float = 0.0
+
+     image_size: int = 224
+     use_abs_posemb: bool = True
+     use_cls_token: bool = False
+     use_rope2d: bool = True
+
+     pool_type: str = "attn"
+     attn_pooler_heads: int = 8
+
+     use_ln_pre: bool = True
+     use_ln_post: bool = True
+
+
+ @dataclass
+ class PETextConfig:
+     """ Text Tower Config. """
+     context_length: int
+     width: int
+     heads: int
+     layers: int
+
+     output_dim: int
+
+     mlp_ratio: float = 4.0
+     vocab_size: int = 49408
+
+
+ PE_VISION_CONFIG = {}
+ PE_TEXT_CONFIG = {}
+
+
+ #########################################
+ #                PE CORE                #
+ #########################################
+
+ PE_VISION_CONFIG["PE-Core-G14-448"] = PEConfig(
+     image_size=448,
+     patch_size=14,
+     width=1536,
+     layers=50,
+     heads=16,
+     mlp_ratio=8960 / 1536,
+     pool_type="attn",
+     output_dim=1280,
+     use_cls_token=False,
+ )
+ PE_TEXT_CONFIG["PE-Core-G14-448"] = PETextConfig(
+     context_length=72,
+     width=1280,
+     heads=20,
+     layers=24,
+     output_dim=1280,
+ )
+
+
+ PE_VISION_CONFIG["PE-Core-L14-336"] = PEConfig(
+     image_size=336,
+     patch_size=14,
+     width=1024,
+     layers=24,
+     heads=16,
+     mlp_ratio=4.0,
+     pool_type="attn",
+     output_dim=1024,
+     use_cls_token=True,
+ )
+ PE_TEXT_CONFIG["PE-Core-L14-336"] = PETextConfig(
+     context_length=32,
+     width=1024,
+     heads=16,
+     layers=24,
+     output_dim=1024,
+ )
+
+
+ PE_VISION_CONFIG["PE-Core-B16-224"] = PEConfig(
+     image_size=224,
+     patch_size=16,
+     width=768,
+     layers=12,
+     heads=12,
+     mlp_ratio=4.0,
+     pool_type="attn",
+     output_dim=1024,
+     use_cls_token=True,
+ )
+ PE_TEXT_CONFIG["PE-Core-B16-224"] = PE_TEXT_CONFIG["PE-Core-L14-336"]
+
+
+ PE_VISION_CONFIG["PE-Core-S16-384"] = PEConfig(
+     image_size=384,
+     patch_size=16,
+     width=384,
+     layers=12,
+     heads=6,
+     mlp_ratio=4.0,
+     pool_type="attn",
+     output_dim=512,
+     use_cls_token=True,
+ )
+ PE_TEXT_CONFIG["PE-Core-S16-384"] = PETextConfig(
+     context_length=32,
+     width=512,
+     heads=8,
+     layers=12,
+     output_dim=512,
+ )
+
+
+ PE_VISION_CONFIG["PE-Core-T16-384"] = PEConfig(
+     image_size=384,
+     patch_size=16,
+     width=192,
+     layers=12,
+     heads=3,
+     mlp_ratio=4.0,
+     pool_type="attn",
+     output_dim=512,
+     use_cls_token=True,
+ )
+ PE_TEXT_CONFIG["PE-Core-T16-384"] = PE_TEXT_CONFIG["PE-Core-S16-384"]
+
+
+ #########################################
+ #                PE Lang                #
+ #########################################
+
+ PE_VISION_CONFIG["PE-Lang-G14-448"] = replace(
+     PE_VISION_CONFIG["PE-Core-G14-448"],
+     image_size=448,
+     pool_type="none",
+     use_ln_post=False,
+     output_dim=None,
+     ls_init_value=0.1,
+     layers=47,
+ )
+
+ PE_VISION_CONFIG["PE-Lang-L14-448"] = replace(
+     PE_VISION_CONFIG["PE-Core-L14-336"],
+     image_size=448,
+     pool_type="none",
+     use_ln_post=False,
+     output_dim=None,
+     ls_init_value=0.1,
+     layers=23,
+ )
+
+
+ # Stage 2 checkpoints for PLM-8B and PLM-3B respectively. Pretrained with tiling.
+ # Use these checkpoints if you're building a model that uses tiling downstream!
+ PE_VISION_CONFIG["PE-Lang-G14-448-Tiling"] = PE_VISION_CONFIG["PE-Lang-G14-448"]
+ PE_VISION_CONFIG["PE-Lang-L14-448-Tiling"] = PE_VISION_CONFIG["PE-Lang-L14-448"]
+
+
+ #########################################
+ #               PE Spatial              #
+ #########################################
+
+ PE_VISION_CONFIG["PE-Spatial-G14-448"] = replace(
+     PE_VISION_CONFIG["PE-Core-G14-448"],
+     image_size=448,
+     pool_type="none",
+     use_ln_post=False,
+     output_dim=None,
+     ls_init_value=0.1,
+ )
+
+ # No layerscale on the smaller spatial models
+ PE_VISION_CONFIG["PE-Spatial-L14-448"] = replace(
+     PE_VISION_CONFIG["PE-Core-L14-336"],
+     image_size=448,
+     pool_type="none",
+     use_ln_post=False,
+     output_dim=None,
+ )
+
+ PE_VISION_CONFIG["PE-Spatial-B16-512"] = replace(
+     PE_VISION_CONFIG["PE-Core-B16-224"],
+     image_size=512,
+     pool_type="none",
+     use_ln_post=False,
+     output_dim=None,
+ )
+
+ PE_VISION_CONFIG["PE-Spatial-S16-512"] = replace(
+     PE_VISION_CONFIG["PE-Core-S16-384"],
+     image_size=512,
+     pool_type="none",
+     use_ln_post=False,
+     output_dim=None,
+ )
+
+ PE_VISION_CONFIG["PE-Spatial-T16-512"] = replace(
+     PE_VISION_CONFIG["PE-Core-T16-384"],
+     image_size=512,
+     pool_type="none",
+     use_ln_post=False,
+     output_dim=None,
+ )
core/vision_encoder/pe.py ADDED
@@ -0,0 +1,761 @@
+ import copy
+ import math
+ import random
+ from collections import OrderedDict
+ from dataclasses import asdict
+ from functools import partial
+ from logging import getLogger
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Literal
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from timm.layers import DropPath
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
+ from torch.nn.parameter import Parameter
+ from torch.utils.checkpoint import checkpoint
+
+ from core.vision_encoder.rope import Rope2D
+ from core.vision_encoder.config import PEConfig, PETextConfig, PE_VISION_CONFIG, PE_TEXT_CONFIG, fetch_pe_checkpoint
+
+
+ logger = getLogger()
+
+
+ class LayerScale(nn.Module):
+     def __init__(self, dim, init_values=1e-5, inplace=False):
+         super().__init__()
+         self.inplace = inplace
+         self.dim = dim
+         self.init_values = init_values
+
+     def forward(self, x):
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+     def init_tensors(self):
+         self.gamma = nn.Parameter(self.init_values * torch.ones(self.dim))
+
+
+ class AttentionPooling(nn.Module):
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         num_probe: int = 1,
+         mlp_ratio: int = 4,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = nn.LayerNorm,
+     ):
+         super().__init__()
+
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+
+         assert (
+             self.embed_dim % num_heads == 0
+         ), "embed_dim must be divisible by num_heads"
+
+         self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim))
+         self.attn = nn.MultiheadAttention(
+             self.embed_dim, self.num_heads, batch_first=True
+         )
+
+         self.layernorm = norm_layer(embed_dim)
+         self.mlp_width = int(embed_dim * mlp_ratio)
+         self.mlp = nn.Sequential(
+             OrderedDict(
+                 [
+                     ("c_fc", nn.Linear(self.embed_dim, self.mlp_width)),
+                     ("gelu", act_layer()),
+                     ("c_proj", nn.Linear(self.mlp_width, self.embed_dim)),
+                 ]
+             )
+         )
+
+     def forward(self, x: torch.Tensor):
+         batch, _, _ = x.shape
+
+         q = self.probe.repeat((batch, 1, 1)).to(x.dtype)
+         x = self.attn(q, x, x, need_weights=False)[0]
+         x = x + self.mlp(self.layernorm(x))
+
+         return x
+
+
+ class SelfAttention(nn.Module):
+     r"""
+     Implements sequence packed attention and RoPE
+     """
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         rope: Optional[nn.Module] = None,
+     ):
+         super(SelfAttention, self).__init__()
+         self.embed_dim = embed_dim
+
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         assert (
+             self.head_dim * num_heads == self.embed_dim
+         ), "embed_dim must be divisible by num_heads"
+
+         # To make this compatible with nn.MultiheadAttention
+         self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
+         self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
+         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
+         self.rope = rope
+         self.scale = self.head_dim ** (-0.5)
+
+     def init_tensors(self):
+         xavier_uniform_(self.in_proj_weight)
+         constant_(self.in_proj_bias, 0.0)
+         constant_(self.out_proj.bias, 0.0)
+
+     def forward(self, x, attn_mask=None):
+         batch, seq, embed_dim = x.shape
+         proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)
+
+         # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
+         proj = (
+             proj.unflatten(-1, (3, embed_dim))
+             .unsqueeze(0)
+             .transpose(0, -2)
+             .squeeze(-2)
+             .contiguous()
+         )
+         q, k, v = proj[0], proj[1], proj[2]
+
+         q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
+         k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
+         v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)
+
+         if self.rope:
+             q, k = self.rope(q, k)
+
+         attn = F.scaled_dot_product_attention(
+             q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale
+         )
+         attn = rearrange(attn, "b h s d -> b s (h d)")
+
+         return F.linear(attn, self.out_proj.weight, self.out_proj.bias)
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         d_model: int,
+         n_head: int,
+         mlp_ratio: float = 4.0,
+         ls_init_value: float = None,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = nn.LayerNorm,
+         drop_path: float = 0.0,
+         rope: Optional[nn.Module] = None,
+     ):
+         super().__init__()
+
+         if rope:
+             self.attn = SelfAttention(d_model, n_head, rope=rope)
+         else:
+             self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
+
+         self.ls_1 = (
+             LayerScale(d_model, ls_init_value)
+             if ls_init_value is not None
+             else nn.Identity()
+         )
+         self.ls_2 = (
+             LayerScale(d_model, ls_init_value)
+             if ls_init_value is not None
+             else nn.Identity()
+         )
+
+         self.ln_1 = norm_layer(d_model)
+         self.ln_2 = norm_layer(d_model)
+
+         self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+         self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         mlp_width = int(d_model * mlp_ratio)
+         self.mlp = nn.Sequential(
+             OrderedDict(
+                 [
+                     ("c_fc", nn.Linear(d_model, mlp_width)),
+                     ("gelu", act_layer()),
+                     ("c_proj", nn.Linear(mlp_width, d_model)),
+                 ]
+             )
+         )
+
+     def _call_attn(
+         self,
+         q_x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+     ):
+         if attn_mask is not None:
+             # Leave boolean masks as is
+             if not attn_mask.dtype == torch.bool:
+                 attn_mask = attn_mask.to(q_x.dtype)
+
+         if isinstance(self.attn, SelfAttention):
+             return self.attn(q_x, attn_mask=attn_mask)
+         else:
+             return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0]
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+     ):
+         x = x + self.drop_path1(
+             self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask))
+         )
+         x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))
+         return x
+
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         width: int,
+         layers: int,
+         heads: int,
+         mlp_ratio: float = 4.0,
+         ls_init_value: float = None,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = nn.LayerNorm,
+         drop_path: float = 0.0,
+         rope: Optional[nn.Module] = None,
+     ):
+         super().__init__()
+         self.width = width
+         self.layers = layers
+         self.grad_checkpointing = False
+
+         self.resblocks = nn.ModuleList(
+             [
+                 ResidualAttentionBlock(
+                     width,
+                     heads,
+                     mlp_ratio,
+                     ls_init_value=ls_init_value,
+                     act_layer=act_layer,
+                     norm_layer=norm_layer,
+                     drop_path=drop_path,
+                     rope=rope,
+                 )
+                 for _ in range(layers)
+             ]
+         )
+
+     @torch.jit.ignore
+     def set_grad_checkpointing(self, enable=True):
+         self.grad_checkpointing = enable
+
+     @torch.jit.ignore
+     def truncate(self, layer_idx: int):
+         """ Delete layers so the last layer is the given layer index. """
+         self.layers = ((self.layers + layer_idx) % self.layers) + 1
+         self.resblocks = nn.ModuleList(self.resblocks[:self.layers])
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+         layer_idx: int = -1,
+     ):
+         stop_idx = (self.layers + layer_idx) % self.layers
+
+         for i, r in enumerate(self.resblocks):
+             if self.grad_checkpointing and not torch.jit.is_scripting():
+                 # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
+                 x = checkpoint(r, x, attn_mask)
+             else:
+                 x = r(x, attn_mask=attn_mask)
+
+             if i == stop_idx:
+                 break
+
+         return x
+
+
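The `layer_idx` convention used by `Transformer.truncate` and `Transformer.forward` follows Python's negative indexing: `stop_idx = (layers + layer_idx) % layers`. The arithmetic can be checked standalone:

```python
def stop_idx(layers: int, layer_idx: int) -> int:
    # Same expression as in Transformer.forward: negative indices count
    # from the end of the block stack, like Python list indexing.
    return (layers + layer_idx) % layers

layers = 50  # e.g. the PE-Core-G14-448 depth

last = stop_idx(layers, -1)  # index of the final block
cut = stop_idx(layers, -4)   # stop three blocks before the end
mid = stop_idx(layers, 10)   # non-negative indices pass through unchanged
```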
293
+ class VisionTransformer(nn.Module):
294
+ def __init__(
295
+ self,
296
+ patch_size: int,
297
+ width: int,
298
+ layers: int,
299
+ heads: int,
300
+ mlp_ratio: float,
301
+ act_layer: Callable = nn.GELU,
302
+ norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
303
+ use_ln_pre: bool = True,
304
+ use_ln_post: bool = True,
305
+ ls_init_value: float = None,
306
+ drop_path: float = 0.0,
307
+ image_size: int = 448, # Pretrain image size only; you can pass in any image size
308
+ use_abs_posemb: bool = True,
309
+ use_rope2d: bool = True,
310
+ use_cls_token: bool = False,
311
+ output_dim: Optional[int] = 1280,
312
+ attn_pooler_heads: int = 8,
313
+ pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
314
+ ):
315
+ super().__init__()
316
+ assert pool_type in ("attn", "tok", "avg", "none")
317
+ self.pool_type = pool_type
318
+ self.patch_size = patch_size
319
+
320
+ self.output_dim = output_dim or width
321
+ self.proj_dim = output_dim
322
+ self.heads = heads
323
+ self.width = width
324
+ self.layers = layers
325
+
326
+ self.use_abs_posemb = use_abs_posemb
327
+ self.use_cls_token = use_cls_token
328
+ self.use_rope2d = use_rope2d
329
+ self.image_size = image_size
330
+
331
+ self.conv1 = nn.Conv2d(
332
+ in_channels=3,
333
+ out_channels=width,
334
+ kernel_size=patch_size,
335
+ stride=patch_size,
336
+ bias=False,
337
+ )
338
+ self.rope = (
339
+ Rope2D(
340
+ dim=width // heads,
341
+ use_cls_token=self.use_cls_token,
342
+ )
343
+ if self.use_rope2d
344
+ else None
345
+ )
346
+
347
+ self.ln_pre = norm_layer(width) if use_ln_pre else nn.Identity()
348
+ self.ln_post = norm_layer(self.width) if use_ln_post else nn.Identity()
349
+
350
+ self.transformer = Transformer(
351
+ width,
352
+ layers,
353
+ heads,
354
+ mlp_ratio,
355
+ ls_init_value=ls_init_value,
356
+ act_layer=act_layer,
357
+ norm_layer=norm_layer,
358
+ drop_path=drop_path,
359
+ rope=self.rope,
360
+ )
361
+
362
+ if pool_type == "attn":
363
+ self.attn_pool = AttentionPooling(
364
+ embed_dim=width,
365
+ num_heads=attn_pooler_heads,
366
+ act_layer=act_layer,
367
+ norm_layer=norm_layer,
368
+ )
369
+ else:
370
+ self.attn_pool = None
371
+
372
+ self.init_tensors()
373
+
374
+
375
+ def init_tensors(self):
376
+ def init_submodule_tensors(module):
377
+ for name, child in module.named_children():
378
+ if hasattr(child, "init_tensors"):
379
+ logger.debug(f"Initializing tensors for submodule: {name}")
380
+ child.init_tensors()
381
+ init_submodule_tensors(child)
382
+
383
+ init_submodule_tensors(self)
384
+ self.rope.init_tensors()
385
+
386
+ # class embeddings and positional embeddings
387
+ init_scale = self.width**-0.5
388
+
389
+ if self.use_cls_token:
390
+ self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))
391
+
392
+ if self.use_abs_posemb:
393
+ self.posemb_grid_size = self.image_size // self.patch_size
394
+ self.positional_embedding = nn.Parameter(
395
+ init_scale
396
+ * torch.randn(
397
+ int(self.use_cls_token) + self.posemb_grid_size**2, self.width
398
+ )
399
+ )
400
+
401
+ if self.proj_dim is not None:
402
+ self.proj = nn.Parameter(
403
+ init_scale * torch.randn(self.width, self.proj_dim)
404
+ )
405
+
406
+
407
+ def load_ckpt(self, ckpt_path: str, verbose: bool = True):
408
+ _sd = torch.load(ckpt_path, weights_only=True)
409
+ if "state_dict" in _sd:
410
+ _sd = _sd["state_dict"]
411
+ elif "weights" in _sd:
412
+ _sd = _sd["weights"]
413
+
414
+ # for backwards compatibility
415
+ _sd = {k.replace("module.", ""): v for k, v in _sd.items()}
416
+ if any(k.startswith("visual.") for k in _sd):
417
+ _sd = {k.replace("visual.", ""): v for k, v in _sd.items() if "visual" in k}
418
+
419
+ m, u = self.load_state_dict(_sd, strict=False)
420
+
421
+ if verbose or (m or u):
422
+ logger.info(f"Missing keys for loading vision encoder: {m}")
423
+ logger.info(f"Unexpected keys for loading vision encoder: {u}")
424
+ print(f"Missing keys for loading vision encoder: {m}")
425
+ print(f"Unexpected keys for loading vision encoder: {u}")
426
+
427
+
428
+ def truncate(self, layer_idx: int):
429
+ """ Delete layers so the last layer is the given layer index. """
430
+ self.transformer.truncate(layer_idx)
431
+ self.layers = self.transformer.layers
432
+
433
+
434
+ @classmethod
435
+ def from_config(
436
+ cls,
437
+ name: str,
438
+ pretrained: bool = False,
439
+ checkpoint_path: Optional[str] = None,
440
+ **kwdargs
441
+ ):
442
+ if name not in PE_VISION_CONFIG:
443
+ raise RuntimeError(f"{name} not found in configs.")
444
+
445
+ args = asdict(PE_VISION_CONFIG[name])
446
+ args.update(kwdargs)
447
+
448
+ model = cls(**args)
449
+ if pretrained:
450
+ model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))
451
+
452
+ return model
453
+
454
+ @classmethod
455
+ def available_configs(cls):
456
+ return list(PE_VISION_CONFIG.keys())
457
+
458
+
459
+ @torch.jit.ignore
460
+ def set_grad_checkpointing(self, enable=True):
461
+ self.transformer.set_grad_checkpointing(enable=enable)
462
+
463
+ def _sample_abs_posemb(self, grid_h: int, grid_w: int):
464
+ """Interpolates the absolute position embedding if necessary."""
465
+ if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
466
+ return self.positional_embedding[None, ...]
467
+
468
+ pos_embed = self.positional_embedding
469
+ if self.use_cls_token:
470
+ cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
471
+
472
+ pos_embed = (
473
+ pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
474
+ .permute(0, 3, 1, 2)
475
+ .contiguous()
476
+ )
477
+ pos_embed = F.interpolate(
478
+ pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
479
+ )
480
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()
481
+
482
+ if self.use_cls_token:
483
+ pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
484
+
485
+ return pos_embed[None, ...]
486
+
487
+ def _pool(self, x: torch.Tensor):
488
+ if self.pool_type == "tok":
489
+ return x[:, 0]
490
+ elif self.pool_type == "avg":
491
+ return x.mean(dim=1)
492
+ elif self.pool_type == "attn":
493
+ return self.attn_pool(x).squeeze(1)
494
+ elif self.pool_type == "none":
495
+ return x
496
+ else:
497
+ raise NotImplementedError
498
+
499
+ def forward_features(
500
+ self,
501
+ x: torch.Tensor,
502
+ norm: bool = False,
503
+ layer_idx: int = -1,
504
+ strip_cls_token: bool = False
505
+ ):
506
+ batch, _, h, w = x.shape
507
+ grid_h, grid_w = h // self.patch_size, w // self.patch_size
508
+
509
+ x = self.conv1(x)
510
+ x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)
511
+
512
+ if self.use_cls_token:
513
+ x = torch.cat(
514
+ [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
515
+ dim=1,
516
+ )
517
+
518
+ if self.use_abs_posemb:
519
+ x = x + self._sample_abs_posemb(grid_h, grid_w)
520
+
521
+ if self.use_rope2d:
522
+ self.rope.update_grid(x.device, grid_h, grid_w)
523
+
524
+ x = self.ln_pre(x)
525
+ x = self.transformer(x, layer_idx=layer_idx)
526
+
527
+ if norm:
528
+ x = self.ln_post(x)
529
+
530
+ if strip_cls_token and self.use_cls_token:
531
+ x = x[:, 1:, :]
532
+
533
+ return x
534
+
535
+ def forward(self, x: torch.Tensor, **kwargs):
536
+ x = self.forward_features(x, norm=True, **kwargs)
537
+ x = self._pool(x)
538
+
539
+ if self.proj_dim is not None:
540
+ x = x @ self.proj
541
+
542
+ return x
543
+
544
+
545
+
546
+
547
+
548
+
549
+
550
+
551
+
552
+ class TextTransformer(nn.Module):
553
+ def __init__(
554
+ self,
555
+ context_length: int = 72,
556
+ vocab_size: int = 49408,
557
+ width: int = 512,
558
+ heads: int = 8,
559
+ layers: int = 12,
560
+ mlp_ratio: float = 4.0,
561
+ ls_init_value: float = None,
562
+ output_dim: int = 1280,
563
+ no_causal_mask: bool = False,
564
+ pad_id: int = 0,
565
+ pool_type: str = "argmax",
566
+ proj_bias: bool = False,
567
+ act_layer: Callable = nn.GELU,
568
+ norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
569
+ output_tokens: bool = False,
570
+ use_ln_post: bool = True,
571
+ ):
572
+ super().__init__()
573
+ assert pool_type in ("first", "last", "argmax", "none")
574
+ self.pool_type = pool_type
575
+ self.output_tokens = output_tokens
576
+ self.num_pos = self.context_length = context_length
577
+ self.vocab_size = vocab_size
578
+ self.width = width
579
+ self.output_dim = output_dim
580
+ self.heads = heads
581
+ self.pad_id = pad_id
582
+ self.layers = layers
583
+
584
+ self.token_embedding = nn.Embedding(vocab_size, width)
585
+ self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
586
+
587
+ self.transformer = Transformer(
588
+ width=width,
589
+ layers=layers,
590
+ heads=heads,
591
+ mlp_ratio=mlp_ratio,
592
+ ls_init_value=ls_init_value,
593
+ act_layer=act_layer,
594
+ norm_layer=norm_layer,
595
+ )
596
+
597
+ self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()
598
+
599
+ if no_causal_mask:
600
+ self.attn_mask = None
601
+ else:
602
+ self.register_buffer(
603
+ "attn_mask", self.build_causal_mask(), persistent=False
604
+ )
605
+
606
+ if pool_type == "attn" or pool_type == "attn_eos":
607
+ self.attn_pool = AttentionPooling(
608
+ embed_dim=width,
609
+ num_heads=heads,
610
+ act_layer=act_layer,
611
+ norm_layer=norm_layer,
612
+ )
613
+ else: # argmax
614
+ self.attn_pool = None
615
+
616
+ if proj_bias:
617
+ self.text_projection = nn.Linear(width, output_dim)
618
+ else:
619
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
620
+
621
+ def build_causal_mask(self):
622
+ # lazily create causal attention mask, with full attention between the tokens
623
+ # pytorch uses additive attention mask; fill with -inf
624
+ mask = torch.empty(self.num_pos, self.num_pos)
625
+ mask.fill_(float("-inf"))
626
+ mask.triu_(1) # zero out the lower diagonal
627
+ return mask
628
+
629
+ def load_ckpt(self, ckpt_path: str, verbose: bool = True):
630
+ _sd = torch.load(ckpt_path, weights_only=True)
631
+ if "state_dict" in _sd:
632
+ _sd = _sd["state_dict"]
633
+ elif "weights" in _sd:
634
+ _sd = _sd["weights"]
635
+
636
+ _sd = {k.replace("module.", ""): v for k, v in _sd.items()}
637
+
638
+ m, u = self.load_state_dict(_sd, strict=False)
639
+
640
+ if verbose or (m or u):
641
+ logger.info(f"Missing keys for loading model: {m}")
642
+ logger.info(f"Unexpected keys for loading model: {u}")
643
+ print(f"Missing keys for loading model: {m}")
644
+ print(f"Unexpected keys for loading model: {u}")
645
+
646
+     def build_cls_mask(self, text):
+         cls_mask = (text != self.pad_id).unsqueeze(1)
+         cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
+         additive_mask = torch.empty(cls_mask.shape, device=cls_mask.device)
+         additive_mask.fill_(0)
+         additive_mask.masked_fill_(~cls_mask, float("-inf"))
+         additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
+         return additive_mask
+ 
+     def text_global_pool(
+         self, x, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
+     ):
+         if pool_type == "first":
+             pooled, tokens = x[:, 0], x[:, 1:]
+         elif pool_type == "last":
+             pooled, tokens = x[:, -1], x[:, :-1]
+         elif pool_type == "argmax":
+             # take features from the eot embedding (eot_token is the highest number in each sequence)
+             assert text is not None
+             pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
+         else:
+             pooled = tokens = x
+ 
+         return pooled, tokens
+ 
+     def forward(self, text):
+         seq_len = text.shape[1]
+         x = self.token_embedding(text)
+         attn_mask = self.attn_mask
+         if attn_mask is not None:
+             attn_mask = attn_mask[:seq_len, :seq_len]
+ 
+         x = x + self.positional_embedding[:seq_len]
+         x = self.transformer(x, attn_mask=attn_mask)
+ 
+         x = self.ln_final(x)
+         pooled, tokens = self.text_global_pool(x, text, pool_type=self.pool_type)
+ 
+         if self.text_projection is not None:
+             if isinstance(self.text_projection, nn.Linear):
+                 pooled = self.text_projection(pooled)
+             else:
+                 pooled = pooled @ self.text_projection
+ 
+         if self.output_tokens:
+             return pooled, tokens
+ 
+         return pooled
+ 
+ 
+ class CLIP(TextTransformer):
+     def __init__(
+         self,
+         vision_cfg: PEConfig,
+         text_cfg: PETextConfig,
+         init_logit_scale: float = np.log(1 / 0.07),
+     ):
+         super(CLIP, self).__init__(**asdict(text_cfg))
+         self.visual = VisionTransformer(**asdict(vision_cfg))
+         self.image_size = self.visual.image_size  # For ease of use
+         self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
+ 
+     def encode_image(self, image, normalize: bool = False):
+         x = self.visual(image)
+         return F.normalize(x, dim=-1) if normalize else x
+ 
+     def encode_video(self, video, normalize: bool = False):  # b n c h w
+         b, n, c, h, w = video.shape
+         frms = video.reshape(b * n, c, h, w)
+         frm_feats = self.encode_image(frms, normalize=normalize)
+         video_feats = frm_feats.reshape(b, n, -1)
+         video_feats = video_feats.mean(dim=1)
+         return video_feats
+ 
+     def encode_text(self, text, normalize: bool = False):
+         x = super().forward(text)
+         return F.normalize(x, dim=-1) if normalize else x
+ 
+     def forward(
+         self,
+         image: Optional[torch.Tensor] = None,
+         text: Optional[torch.Tensor] = None,
+     ):
+         image_features = (
+             self.encode_image(image, normalize=True) if image is not None else None
+         )
+         text_features = (
+             self.encode_text(text, normalize=True) if text is not None else None
+         )
+         return image_features, text_features, self.logit_scale.exp()
+ 
+     @classmethod
+     def from_config(
+         cls,
+         name: str,
+         pretrained: bool = False,
+         checkpoint_path: Optional[str] = None,  # To load your own
+     ):
+         if name not in PE_VISION_CONFIG or name not in PE_TEXT_CONFIG:
+             raise RuntimeError(f"{name} not found in configs.")
+ 
+         model = cls(PE_VISION_CONFIG[name], PE_TEXT_CONFIG[name])
+         if pretrained:
+             model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))
+ 
+         return model
+ 
+     @classmethod
+     def available_configs(cls):
+         return [k for k in PE_VISION_CONFIG if k in PE_TEXT_CONFIG]
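The `text_global_pool` "argmax" branch relies on the EOT token having the highest id in each padded sequence. A minimal stdlib sketch of that rule (the token ids shown are the usual OpenAI CLIP ids and are illustrative; the vendored `SimpleTokenizer` derives its own ids from its vocab):

```python
def eot_position(token_ids):
    """Return the index of the highest token id (the EOT token),
    mirroring text.argmax(dim=-1) in text_global_pool."""
    return max(range(len(token_ids)), key=lambda i: token_ids[i])

# <sot>=49406, two word tokens, <eot>=49407, then zero padding
seq = [49406, 320, 1125, 49407, 0, 0]
print(eot_position(seq))  # -> 3, the position whose features get pooled
```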
core/vision_encoder/rope.py ADDED
@@ -0,0 +1,347 @@
+ from math import log, pi
+ from typing import Literal, Optional, Union
+ 
+ import torch
+ from einops import rearrange, repeat
+ from torch import Tensor, broadcast_tensors, einsum, nn
+ from torch.amp import autocast
+ from torch.nn import Module, ModuleList
+ 
+ # helper functions
+ 
+ 
+ def exists(val):
+     return val is not None
+ 
+ 
+ def default(val, d):
+     return val if exists(val) else d
+ 
+ 
+ # broadcat, as tortoise-tts was using it
+ 
+ 
+ def broadcat(tensors, dim=-1):
+     broadcasted_tensors = broadcast_tensors(*tensors)
+     return torch.cat(broadcasted_tensors, dim=dim)
+ 
+ 
+ # rotary embedding helper functions
+ 
+ 
+ def rotate_half(x):
+     x = rearrange(x, "... (d r) -> ... d r", r=2)
+     x1, x2 = x.unbind(dim=-1)
+     x = torch.stack((-x2, x1), dim=-1)
+     return rearrange(x, "... d r -> ... (d r)")
+ 
+ 
+ @autocast("cuda", enabled=False)
+ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
+     dtype = t.dtype
+ 
+     if t.ndim == 3:
+         seq_len = t.shape[seq_dim]
+         freqs = freqs[-seq_len:]
+ 
+     rot_dim = freqs.shape[-1]
+     end_index = start_index + rot_dim
+ 
+     assert (
+         rot_dim <= t.shape[-1]
+     ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
+ 
+     t_left, t, t_right = (
+         t[..., :start_index],
+         t[..., start_index:end_index],
+         t[..., end_index:],
+     )
+     t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
+     out = torch.cat((t_left, t, t_right), dim=-1)
+ 
+     return out.type(dtype)
+ 
+ 
+ # learned rotation helpers
+ 
+ 
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
+     if exists(freq_ranges):
+         rotations = einsum("..., f -> ... f", rotations, freq_ranges)
+         rotations = rearrange(rotations, "... r f -> ... (r f)")
+ 
+     rotations = repeat(rotations, "... n -> ... (n r)", r=2)
+     return apply_rotary_emb(rotations, t, start_index=start_index)
+ 
+ 
+ # classes
+ 
+ 
+ class RotaryEmbedding(Module):
+     def __init__(
+         self,
+         dim,
+         custom_freqs: Optional[Tensor] = None,
+         freqs_for: Union[
+             Literal["lang"], Literal["pixel"], Literal["constant"]
+         ] = "lang",
+         theta=10000,
+         max_freq=10,
+         num_freqs=1,
+         learned_freq=False,
+         use_xpos=False,
+         xpos_scale_base=512,
+         interpolate_factor=1.0,
+         theta_rescale_factor=1.0,
+         seq_before_head_dim=False,
+         cache_if_possible=True,
+     ):
+         super().__init__()
+         # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+         # has some connection to NTK literature
+         # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+ 
+         theta *= theta_rescale_factor ** (dim / (dim - 2))
+ 
+         self.freqs_for = freqs_for
+ 
+         if exists(custom_freqs):
+             freqs = custom_freqs
+         elif freqs_for == "lang":
+             freqs = 1.0 / (
+                 theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+             )
+         elif freqs_for == "pixel":
+             freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+         elif freqs_for == "constant":
+             freqs = torch.ones(num_freqs).float()
+ 
+         self.cache_if_possible = cache_if_possible
+ 
+         self.tmp_store("cached_freqs", None)
+         self.tmp_store("cached_scales", None)
+ 
+         self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
+ 
+         self.learned_freq = learned_freq
+ 
+         # dummy for device
+ 
+         self.tmp_store("dummy", torch.tensor(0))
+ 
+         # default sequence dimension
+ 
+         self.seq_before_head_dim = seq_before_head_dim
+         self.default_seq_dim = -3 if seq_before_head_dim else -2
+ 
+         # interpolation factors
+ 
+         assert interpolate_factor >= 1.0
+         self.interpolate_factor = interpolate_factor
+ 
+         # xpos
+ 
+         self.use_xpos = use_xpos
+         if not use_xpos:
+             self.tmp_store("scale", None)
+             return
+ 
+         scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+ 
+         self.scale_base = xpos_scale_base
+         self.tmp_store("scale", scale)
+ 
+         # add apply_rotary_emb as static method
+ 
+         self.apply_rotary_emb = staticmethod(apply_rotary_emb)
+ 
+     @property
+     def device(self):
+         return self.dummy.device
+ 
+     def tmp_store(self, key, value):
+         self.register_buffer(key, value, persistent=False)
+ 
+     def get_seq_pos(self, seq_len, device, dtype, offset=0):
+         return (
+             torch.arange(seq_len, device=device, dtype=dtype) + offset
+         ) / self.interpolate_factor
+ 
+     def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
+         seq_dim = default(seq_dim, self.default_seq_dim)
+ 
+         assert (
+             not self.use_xpos
+         ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
+ 
+         device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
+ 
+         freqs = self.forward(
+             self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset),
+             seq_len=seq_len,
+             offset=offset,
+         )
+ 
+         if seq_dim == -3:
+             freqs = rearrange(freqs, "n d -> n 1 d")
+ 
+         return apply_rotary_emb(freqs, t, seq_dim=seq_dim)
+ 
+     def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
+         seq_dim = default(seq_dim, self.default_seq_dim)
+ 
+         q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
+         assert q_len <= k_len
+ 
+         rotated_q = self.rotate_queries_or_keys(
+             q, seq_dim=seq_dim, offset=k_len - q_len + offset
+         )
+         rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset)
+ 
+         rotated_q = rotated_q.type(q.dtype)
+         rotated_k = rotated_k.type(k.dtype)
+ 
+         return rotated_q, rotated_k
+ 
+     def rotate_queries_and_keys(self, q, k, seq_dim=None):
+         seq_dim = default(seq_dim, self.default_seq_dim)
+ 
+         assert self.use_xpos
+         device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
+ 
+         seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
+ 
+         freqs = self.forward(seq, seq_len=seq_len)
+         scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
+ 
+         if seq_dim == -3:
+             freqs = rearrange(freqs, "n d -> n 1 d")
+             scale = rearrange(scale, "n d -> n 1 d")
+ 
+         rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
+         rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)
+ 
+         rotated_q = rotated_q.type(q.dtype)
+         rotated_k = rotated_k.type(k.dtype)
+ 
+         return rotated_q, rotated_k
+ 
+     def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
+         assert self.use_xpos
+ 
+         should_cache = self.cache_if_possible and exists(seq_len)
+ 
+         if (
+             should_cache
+             and exists(self.cached_scales)
+             and (seq_len + offset) <= self.cached_scales.shape[0]
+         ):
+             return self.cached_scales[offset : (offset + seq_len)]
+ 
+         scale = 1.0
+         if self.use_xpos:
+             power = (t - len(t) // 2) / self.scale_base
+             scale = self.scale ** rearrange(power, "n -> n 1")
+             scale = torch.cat((scale, scale), dim=-1)
+ 
+         if should_cache:
+             self.tmp_store("cached_scales", scale)
+ 
+         return scale
+ 
+     def get_axial_freqs(self, *dims):
+         Colon = slice(None)
+         all_freqs = []
+ 
+         for ind, dim in enumerate(dims):
+             if self.freqs_for == "pixel":
+                 pos = torch.linspace(-1, 1, steps=dim, device=self.device)
+             else:
+                 pos = torch.arange(dim, device=self.device)
+ 
+             freqs = self.forward(pos, seq_len=dim)
+ 
+             all_axis = [None] * len(dims)
+             all_axis[ind] = Colon
+ 
+             new_axis_slice = (Ellipsis, *all_axis, Colon)
+             all_freqs.append(freqs[new_axis_slice])
+ 
+         all_freqs = broadcast_tensors(*all_freqs)
+         return torch.cat(all_freqs, dim=-1)
+ 
+     @autocast("cuda", enabled=False)
+     def forward(self, t: Tensor, seq_len=None, offset=0):
+         should_cache = (
+             self.cache_if_possible
+             and not self.learned_freq
+             and exists(seq_len)
+             and self.freqs_for != "pixel"
+         )
+ 
+         if (
+             should_cache
+             and exists(self.cached_freqs)
+             and (offset + seq_len) <= self.cached_freqs.shape[0]
+         ):
+             return self.cached_freqs[offset : (offset + seq_len)].detach()
+ 
+         freqs = self.freqs
+ 
+         freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
+         freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+ 
+         if should_cache:
+             self.tmp_store("cached_freqs", freqs.detach())
+ 
+         return freqs
+ 
+ 
+ class Rope2D:
+     """Helper class to apply RoPE2D as well as interpolate on the fly."""
+ 
+     def __init__(self, dim, use_cls_token=False):
+         self.dim = dim
+         self.use_cls_token = use_cls_token
+         self.grid_size = None
+         self.freq = None
+ 
+     def init_tensors(self):
+         self.rope = RotaryEmbedding(self.dim // 2)
+ 
+     def update_grid(self, device, grid_h, grid_w):
+         if self.grid_size != (grid_h, grid_w):
+             self.grid_size = (grid_h, grid_w)
+ 
+             self.rope = self.rope.to(device)
+ 
+             if self.use_cls_token:
+                 # +1 to leave space for the cls token to be (0, 0)
+                 grid_y_range = torch.arange(grid_h, device=device) + 1
+                 grid_x_range = torch.arange(grid_w, device=device) + 1
+             else:
+                 grid_y_range = torch.arange(grid_h, device=device)
+                 grid_x_range = torch.arange(grid_w, device=device)
+ 
+             freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
+             freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
+             freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)
+ 
+             if self.use_cls_token:
+                 freq = torch.cat(
+                     [torch.zeros(1, freq.shape[-1], device=device), freq], dim=0
+                 )
+ 
+             self.freq = freq[None, ...]
+ 
+         self.freq = self.freq.to(device)
+ 
+     def __call__(self, q, k):
+         # batch, heads, seq, dim = q.shape
+         q = apply_rotary_emb(self.freq[:, None, :, :], q)
+         k = apply_rotary_emb(self.freq[:, None, :, :], k)
+ 
+         return q, k
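The core of `rope.py` is the identity `t * cos(θ) + rotate_half(t) * sin(θ)`: since `rotate_half` maps each feature pair `(x1, x2)` to `(-x2, x1)`, the combination is a plain 2-D rotation per pair and therefore norm-preserving. A pure-Python sketch of one pair (stdlib only; the einops/torch version above applies the same arithmetic across interleaved feature dimensions):

```python
import math

def apply_rotary_pair(x1, x2, theta):
    """Rotate one feature pair: x*cos + rotate_half(x)*sin,
    where rotate_half((x1, x2)) == (-x2, x1)."""
    c, s = math.cos(theta), math.sin(theta)
    return (x1 * c - x2 * s, x2 * c + x1 * s)

r1, r2 = apply_rotary_pair(3.0, 4.0, 0.7)
print(round(math.hypot(r1, r2), 6))  # -> 5.0, the norm is unchanged
```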
core/vision_encoder/tokenizer.py ADDED
@@ -0,0 +1,342 @@
+ """ CLIP tokenizer
+ 
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+ """
+ 
+ import gzip
+ import html
+ import os
+ import random
+ import string
+ from functools import lru_cache, partial
+ from typing import Callable, List, Optional, Union
+ 
+ import ftfy
+ import regex as re
+ import torch
+ 
+ # https://stackoverflow.com/q/62691279
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ 
+ DEFAULT_CONTEXT_LENGTH = 77  # default context length for OpenAI CLIP
+ 
+ 
+ @lru_cache()
+ def default_bpe():
+     return os.path.join(
+         os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
+     )
+ 
+ 
+ @lru_cache()
+ def bytes_to_unicode():
+     """
+     Returns list of utf-8 byte and a corresponding list of unicode strings.
+     The reversible bpe codes work on unicode strings.
+     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+     This is a significant percentage of your normal, say, 32K bpe vocab.
+     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+     And avoids mapping to whitespace/control characters the bpe code barfs on.
+     """
+     bs = (
+         list(range(ord("!"), ord("~") + 1))
+         + list(range(ord("¡"), ord("¬") + 1))
+         + list(range(ord("®"), ord("ÿ") + 1))
+     )
+     cs = bs[:]
+     n = 0
+     for b in range(2**8):
+         if b not in bs:
+             bs.append(b)
+             cs.append(2**8 + n)
+             n += 1
+     cs = [chr(n) for n in cs]
+     return dict(zip(bs, cs))
+ 
+ 
+ def get_pairs(word):
+     """Return set of symbol pairs in a word.
+     Word is represented as tuple of symbols (symbols being variable-length strings).
+     """
+     pairs = set()
+     prev_char = word[0]
+     for char in word[1:]:
+         pairs.add((prev_char, char))
+         prev_char = char
+     return pairs
+ 
+ 
+ def basic_clean(text):
+     text = ftfy.fix_text(text)
+     text = html.unescape(html.unescape(text))
+     return text.strip()
+ 
+ 
+ def whitespace_clean(text):
+     text = re.sub(r"\s+", " ", text)
+     text = text.strip()
+     return text
+ 
+ 
+ def _clean_canonicalize(x):
+     # basic, remove whitespace, remove punctuation, lower case
+     return canonicalize_text(basic_clean(x))
+ 
+ 
+ def _clean_lower(x):
+     # basic, remove whitespace, lower case
+     return whitespace_clean(basic_clean(x)).lower()
+ 
+ 
+ def _clean_whitespace(x):
+     # basic, remove whitespace
+     return whitespace_clean(basic_clean(x))
+ 
+ 
+ def get_clean_fn(type: str):
+     if type == "canonicalize":
+         return _clean_canonicalize
+     elif type == "lower":
+         return _clean_lower
+     elif type == "whitespace":
+         return _clean_whitespace
+     else:
+         assert False, f"Invalid clean function ({type})."
+ 
+ 
+ def canonicalize_text(text, *, keep_punctuation_exact_string=None):
+     """Returns canonicalized `text` (lowercase and punctuation removed).
+ 
+     From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
+ 
+     Args:
+       text: string to be canonicalized.
+       keep_punctuation_exact_string: If provided, then this exact string is kept.
+         For example providing '{}' will keep any occurrences of '{}' (but will
+         still remove '{' and '}' that appear separately).
+     """
+     text = text.replace("_", " ")
+     if keep_punctuation_exact_string:
+         text = keep_punctuation_exact_string.join(
+             part.translate(str.maketrans("", "", string.punctuation))
+             for part in text.split(keep_punctuation_exact_string)
+         )
+     else:
+         text = text.translate(str.maketrans("", "", string.punctuation))
+     text = text.lower()
+     text = re.sub(r"\s+", " ", text)
+     return text.strip()
+ 
+ 
+ class SimpleTokenizer(object):
+     def __init__(
+         self,
+         bpe_path: str = default_bpe(),
+         additional_special_tokens: Optional[List[str]] = None,
+         context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
+         clean: str = "lower",
+         reduction_mask: str = "",
+     ):
+         self.byte_encoder = bytes_to_unicode()
+         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+         merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
+         merges = merges[1 : 49152 - 256 - 2 + 1]
+         merges = [tuple(merge.split()) for merge in merges]
+         vocab = list(bytes_to_unicode().values())
+         vocab = vocab + [v + "</w>" for v in vocab]
+         for merge in merges:
+             vocab.append("".join(merge))
+         special_tokens = ["<start_of_text>", "<end_of_text>"]
+         if additional_special_tokens:
+             special_tokens += additional_special_tokens
+         vocab.extend(special_tokens)
+         self.encoder = dict(zip(vocab, range(len(vocab))))
+         self.decoder = {v: k for k, v in self.encoder.items()}
+         self.bpe_ranks = dict(zip(merges, range(len(merges))))
+         self.cache = {t: t for t in special_tokens}
+         special = "|".join(special_tokens)
+         self.pat = re.compile(
+             special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+             re.IGNORECASE,
+         )
+         self.vocab_size = len(self.encoder)
+         self.all_special_ids = [self.encoder[t] for t in special_tokens]
+         self.sot_token_id = self.all_special_ids[0]
+         self.eot_token_id = self.all_special_ids[1]
+         self.context_length = context_length
+         self.clean_fn = get_clean_fn(clean)
+         self.reduction_fn = (
+             get_reduction_mask_fn(reduction_mask) if reduction_mask else None
+         )
+ 
+     def bpe(self, token):
+         if token in self.cache:
+             return self.cache[token]
+         word = tuple(token[:-1]) + (token[-1] + "</w>",)
+         pairs = get_pairs(word)
+ 
+         if not pairs:
+             return token + "</w>"
+ 
+         while True:
+             bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+             if bigram not in self.bpe_ranks:
+                 break
+             first, second = bigram
+             new_word = []
+             i = 0
+             while i < len(word):
+                 try:
+                     j = word.index(first, i)
+                     new_word.extend(word[i:j])
+                     i = j
+                 except ValueError:  # `first` does not occur in the rest of the word
+                     new_word.extend(word[i:])
+                     break
+ 
+                 if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                     new_word.append(first + second)
+                     i += 2
+                 else:
+                     new_word.append(word[i])
+                     i += 1
+             new_word = tuple(new_word)
+             word = new_word
+             if len(word) == 1:
+                 break
+             else:
+                 pairs = get_pairs(word)
+         word = " ".join(word)
+         self.cache[token] = word
+         return word
+ 
+     def encode(self, text):
+         bpe_tokens = []
+         text = self.clean_fn(text)
+         for token in re.findall(self.pat, text):
+             token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+             bpe_tokens.extend(
+                 self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
+             )
+         return bpe_tokens
+ 
+     def decode(self, tokens):
+         text = "".join([self.decoder[token] for token in tokens])
+         text = (
+             bytearray([self.byte_decoder[c] for c in text])
+             .decode("utf-8", errors="replace")
+             .replace("</w>", " ")
+         )
+         return text
+ 
+     def __call__(
+         self, texts: Union[str, List[str]], context_length: Optional[int] = None
+     ) -> torch.LongTensor:
+         """Returns the tokenized representation of given input string(s)
+ 
+         Parameters
+         ----------
+         texts : Union[str, List[str]]
+             An input string or a list of input strings to tokenize
+         context_length : int
+             The context length to use; all CLIP models use 77 as the context length
+ 
+         Returns
+         -------
+         A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+         """
+         if isinstance(texts, str):
+             texts = [texts]
+ 
+         context_length = context_length or self.context_length
+         assert context_length, "Please set a valid context length"
+ 
+         if self.reduction_fn is not None:
+             # use reduction strategy for tokenize if set, otherwise default to truncation below
+             return self.reduction_fn(
+                 texts,
+                 context_length=context_length,
+                 sot_token_id=self.sot_token_id,
+                 eot_token_id=self.eot_token_id,
+                 encode_fn=self.encode,
+             )
+ 
+         all_tokens = [
+             [self.sot_token_id] + self.encode(text) + [self.eot_token_id]
+             for text in texts
+         ]
+         result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+ 
+         for i, tokens in enumerate(all_tokens):
+             if len(tokens) > context_length:
+                 tokens = tokens[:context_length]  # Truncate
+                 tokens[-1] = self.eot_token_id
+             result[i, : len(tokens)] = torch.tensor(tokens)
+ 
+         return result
+ 
+ 
+ def random_mask_tokenize(
+     texts: Union[str, List[str]],
+     context_length: int,
+     sot_token_id: int,
+     eot_token_id: int,
+     encode_fn: Callable,
+     shuffle: bool = False,
+ ):
+     all_tokens = [encode_fn(text) for text in texts]
+     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+ 
+     for i, tokens in enumerate(all_tokens):
+         tokens = torch.tensor(tokens)
+         num_tokens = len(tokens)
+         if num_tokens > context_length - 2:  # 2 for sot and eot token
+             num_keep = context_length - 2
+             indices = torch.randperm(len(tokens))
+             indices = indices[:num_keep]
+             if not shuffle:
+                 indices = indices.msort()
+             tokens = tokens[indices]
+             num_tokens = num_keep
+         result[i, 0] = sot_token_id
+         result[i, 1 : num_tokens + 1] = tokens
+         result[i, num_tokens + 1] = eot_token_id
+ 
+     return result
+ 
+ 
+ def simple_mask_tokenize(
+     texts: Union[str, List[str]],
+     context_length: int,
+     sot_token_id: int,
+     eot_token_id: int,
+     encode_fn: Callable,
+ ):
+     all_tokens = [encode_fn(text) for text in texts]
+     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+ 
+     for i, tokens in enumerate(all_tokens):
+         num_tokens = len(tokens)
+         if num_tokens > context_length - 2:  # 2 for sot and eot token
+             num_keep = context_length - 2
+             start_index = random.randint(0, num_tokens - num_keep)  # high is incl
+             tokens = tokens[start_index : start_index + num_keep]
+         tokens = [sot_token_id] + tokens + [eot_token_id]
+         result[i, : len(tokens)] = torch.tensor(tokens)
+ 
+     return result
+ 
+ 
+ def get_reduction_mask_fn(type: str):
+     """Choose strategy for dropping (masking) tokens to achieve target context length"""
+     assert type in ("simple", "random", "shuffle")
+     if type == "simple":
+         return simple_mask_tokenize  # randomly select block [start:end]
+     elif type == "random":
+         return random_mask_tokenize  # randomly drop tokens (keep order)
+     elif type == "shuffle":
+         return partial(
+             random_mask_tokenize, shuffle=True
+         )  # randomly drop tokens (shuffle order)
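`bytes_to_unicode` must produce a reversible table covering all 256 byte values, otherwise `decode` could not invert `encode`. A stdlib-only check of that property, re-deriving the same table as the function above:

```python
def byte_table():
    """Same construction as bytes_to_unicode: printable byte ranges map to
    themselves; the remaining bytes map to code points >= 256."""
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

table = byte_table()
print(len(table), len(set(table.values())))  # -> 256 256, i.e. a bijection
```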
core/vision_encoder/transforms.py ADDED
@@ -0,0 +1,31 @@
+ import torchvision.transforms as T
+ 
+ from core.vision_encoder.tokenizer import SimpleTokenizer
+ 
+ 
+ def get_image_transform(
+     image_size: int,
+     center_crop: bool = False,
+     interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,  # We used bilinear during training
+ ):
+     if center_crop:
+         crop = [
+             T.Resize(image_size, interpolation=interpolation),
+             T.CenterCrop(image_size),
+         ]
+     else:
+         # "Squash": most versatile
+         crop = [
+             T.Resize((image_size, image_size), interpolation=interpolation)
+         ]
+ 
+     return T.Compose(crop + [
+         T.Lambda(lambda x: x.convert("RGB")),
+         T.ToTensor(),
+         T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
+     ])
+ 
+ 
+ def get_text_tokenizer(context_length: int):
+     return SimpleTokenizer(context_length=context_length)
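The `Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])` step maps the `[0, 1]` range produced by `ToTensor` onto `[-1, 1]` via `(x - mean) / std`. A scalar sketch of that arithmetic:

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-channel normalization as applied by T.Normalize."""
    return (x - mean) / std

print(normalize(0.0), normalize(0.5), normalize(1.0))  # -> -1.0 0.0 1.0
```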
models.py ADDED
@@ -0,0 +1,331 @@
+ """Minimal model-loading code for the 7 VFM baselines in the paper."""
+ 
+ from __future__ import annotations
+ 
+ import sys
+ from pathlib import Path
+ from typing import Callable
+ 
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms
+ from transformers import AutoConfig, AutoImageProcessor, AutoModel
+ 
+ ROOT = Path(__file__).resolve().parent
+ WEIGHTS_DIR = ROOT / "weights"
+ 
+ MODEL_SPECS = {
+     "metacliplin": {
+         "paper_name": "MetaCLIP-Linear",
+         "checkpoint": "metacliplin0.pth",
+         "hf_model": "facebook/metaclip-h14-fullcc2.5b",
+         "feature_dim": 1280,
+         "image_size": 224,
+         "pooler_output": True,
+     },
+     "metaclip2lin": {
+         "paper_name": "MetaCLIP2-Linear",
+         "checkpoint": "metaclip2lin0.pth",
+         "hf_model": "facebook/metaclip-2-worldwide-giant",
+         "feature_dim": 1280,
+         "image_size": 224,
+         "pooler_output": True,
+     },
+     "sigliplin": {
+         "paper_name": "SigLIP-Linear",
+         "checkpoint": "sigliplin0.pth",
+         "hf_model": "google/siglip-large-patch16-384",
+         "feature_dim": 1024,
+         "image_size": 384,
+         "pooler_output": True,
+     },
+     "siglip2lin": {
+         "paper_name": "SigLIP2-Linear",
+         "checkpoint": "siglip2lin0.pth",
+         "hf_model": "google/siglip2-giant-opt-patch16-384",
+         "feature_dim": 1536,
+         "image_size": 384,
+         "pooler_output": True,
+     },
+     "pelin": {
+         "paper_name": "PE-CLIP-Linear",
+         "checkpoint": "pelin0.pth",
+         "feature_dim": 1024,
+         "image_size": 336,
+         "pooler_output": False,
+     },
+     "dinov2lin": {
+         "paper_name": "DINOv2-Linear",
+         "checkpoint": "dinov2lin0.pth",
+         "feature_dim": 1024,
+         "pooler_output": False,
+     },
+     "dinov3lin": {
+         "paper_name": "DINOv3-Linear",
+         "checkpoint": "dinov3lin0.pth",
+         "hf_model": "facebook/dinov3-vit7b16-pretrain-lvd1689m",
+         "feature_dim": 4096,
+         "pooler_output": False,
+     },
+ }
+ 
+ ALIASES = {
+     "MetaCLIP-Linear": "metacliplin",
+     "MetaCLIP2-Linear": "metaclip2lin",
+     "SigLIP-Linear": "sigliplin",
+     "SigLIP2-Linear": "siglip2lin",
+     "PE-CLIP-Linear": "pelin",
+     "DINOv2-Linear": "dinov2lin",
+     "DINOv3-Linear": "dinov3lin",
+ }
+ 
+ 
+ def canonical_model_name(name: str) -> str:
+     if name in MODEL_SPECS:
+         return name
+     if name in ALIASES:
+         return ALIASES[name]
+     raise KeyError(f"Unknown model: {name}")
+ 
+ 
+ def default_checkpoint_path(model_name: str) -> Path:
+     model_name = canonical_model_name(model_name)
+     return WEIGHTS_DIR / MODEL_SPECS[model_name]["checkpoint"]
+ 
+ 
+ def _resolve_device(device: str | torch.device | None = None) -> torch.device:
+     if device is None:
+         return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     return torch.device(device)
+ 
+ 
+ def _load_checkpoint(checkpoint_path: str | Path) -> dict:
+     checkpoint = torch.load(str(checkpoint_path), map_location="cpu", weights_only=False)
+     if isinstance(checkpoint, dict):
+         for key in ("state_dict", "model", "model_state_dict"):
+             if key in checkpoint and isinstance(checkpoint[key], dict):
+                 checkpoint = checkpoint[key]
+                 break
+ 
+     normalized = {}
+     for key, value in checkpoint.items():
+         normalized[key[7:] if key.startswith("module.") else key] = value
+     return normalized
+ 
+ 
+ def _infer_feature_dim(state_dict: dict, default_dim: int) -> int:
+     head_weight = state_dict.get("head.weight")
+     if isinstance(head_weight, torch.Tensor) and head_weight.ndim == 2:
+         return int(head_weight.shape[1])
+     return default_dim
+ 
+ 
+ def _load_image_processor(model_name: str):
+     try:
+         return AutoImageProcessor.from_pretrained(model_name, local_files_only=True)
+     except Exception:
+         try:
+             return AutoImageProcessor.from_pretrained(model_name)
+         except Exception:
+             return None
+ 
+ 
+ def _load_backbone(model_name: str):
+     try:
+         return AutoModel.from_pretrained(model_name, local_files_only=True)
+     except Exception:
+         config = AutoConfig.from_pretrained(model_name)
+         return AutoModel.from_config(config)
+ 
+ 
+ class _PoolerLinearModel(nn.Module):
+     def __init__(self, backbone: nn.Module, feature_dim: int):
+         super().__init__()
+         self.backbone = backbone
+         self.head = nn.Linear(feature_dim, 2)
+ 
+     def forward(self, x):
+         with torch.no_grad():
+             outputs = self.backbone(x)
+             features = outputs.pooler_output.float()
+         return self.head(features)
+ 
+ 
+ class _ClsTokenLinearModel(nn.Module):
+     def __init__(self, backbone: nn.Module, feature_dim: int):
+         super().__init__()
+         self.backbone = backbone
+         self.head = nn.Linear(feature_dim, 2)
+ 
+     def forward(self, x):
+         with torch.no_grad():
+             outputs = self.backbone(x)
+             features = outputs.last_hidden_state[:, 0].float()
+         return self.head(features)
+ 
+ 
+ class _PELinearModel(nn.Module):
+     def __init__(self, backbone: nn.Module, feature_dim: int):
+         super().__init__()
+         self.backbone = backbone
+         self.head = nn.Linear(feature_dim, 2)
+ 
+     def forward(self, x):
+         with torch.no_grad():
+             features = self.backbone(x)
+             if isinstance(features, torch.Tensor):
+                 features = features.float()
+         return self.head(features)
+ 
+ 
+ def _finalize_model(model: nn.Module, state_dict: dict, device=None) -> nn.Module:
+     model.load_state_dict(state_dict, strict=False)
+     model.to(_resolve_device(device))
+     model.eval()
+     return model
+ 
+ 
+ def _build_clip_transform(image_size: int, image_processor=None):
+     mean = [0.485, 0.456, 0.406]
+     std = [0.229, 0.224, 0.225]
+     if image_processor is not None:
+         mean = getattr(image_processor, "image_mean", mean)
+         std = getattr(image_processor, "image_std", std)
+     return transforms.Compose(
+         [
+             transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
+             transforms.CenterCrop(image_size),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=mean, std=std),
+         ]
+     )
+ 
+ 
+ def _build_dino_transform():
+     return transforms.Compose(
+         [
+             transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
+             transforms.CenterCrop(224),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         ]
+     )
+ 
+ 
+ def load_metacliplin(checkpoint_path: str | Path | None = None, device=None):
+     spec = MODEL_SPECS["metacliplin"]
+     checkpoint_path = checkpoint_path or default_checkpoint_path("metacliplin")
+     state_dict = _load_checkpoint(checkpoint_path)
+     feature_dim = _infer_feature_dim(state_dict, spec["feature_dim"])
+     image_processor = _load_image_processor(spec["hf_model"])
+     backbone = _load_backbone(spec["hf_model"])
+     model = _PoolerLinearModel(backbone.vision_model, feature_dim)
+     model = _finalize_model(model, state_dict, device=device)
224
+ return model, _build_clip_transform(spec["image_size"], image_processor)
225
+
226
+
227
+ def load_metaclip2lin(checkpoint_path: str | Path | None = None, device=None):
228
+ spec = MODEL_SPECS["metaclip2lin"]
229
+ checkpoint_path = checkpoint_path or default_checkpoint_path("metaclip2lin")
230
+ state_dict = _load_checkpoint(checkpoint_path)
231
+ feature_dim = _infer_feature_dim(state_dict, spec["feature_dim"])
232
+ image_processor = _load_image_processor(spec["hf_model"])
233
+ backbone = _load_backbone(spec["hf_model"])
234
+ model = _PoolerLinearModel(backbone.vision_model, feature_dim)
235
+ model = _finalize_model(model, state_dict, device=device)
236
+ return model, _build_clip_transform(spec["image_size"], image_processor)
237
+
238
+
239
+ def load_sigliplin(checkpoint_path: str | Path | None = None, device=None):
240
+ spec = MODEL_SPECS["sigliplin"]
241
+ checkpoint_path = checkpoint_path or default_checkpoint_path("sigliplin")
242
+ state_dict = _load_checkpoint(checkpoint_path)
243
+ feature_dim = _infer_feature_dim(state_dict, spec["feature_dim"])
244
+ image_processor = _load_image_processor(spec["hf_model"])
245
+ backbone = _load_backbone(spec["hf_model"])
246
+ model = _PoolerLinearModel(backbone.vision_model, feature_dim)
247
+ model = _finalize_model(model, state_dict, device=device)
248
+ return model, _build_clip_transform(spec["image_size"], image_processor)
249
+
250
+
251
+ def load_siglip2lin(checkpoint_path: str | Path | None = None, device=None):
252
+ spec = MODEL_SPECS["siglip2lin"]
253
+ checkpoint_path = checkpoint_path or default_checkpoint_path("siglip2lin")
254
+ state_dict = _load_checkpoint(checkpoint_path)
255
+ feature_dim = _infer_feature_dim(state_dict, spec["feature_dim"])
256
+ image_processor = _load_image_processor(spec["hf_model"])
257
+ backbone = _load_backbone(spec["hf_model"])
258
+ model = _PoolerLinearModel(backbone.vision_model, feature_dim)
259
+ model = _finalize_model(model, state_dict, device=device)
260
+ return model, _build_clip_transform(spec["image_size"], image_processor)
261
+
262
+
263
+ def load_dinov2lin(checkpoint_path: str | Path | None = None, device=None):
264
+ checkpoint_path = checkpoint_path or default_checkpoint_path("dinov2lin")
265
+ state_dict = _load_checkpoint(checkpoint_path)
266
+ feature_dim = _infer_feature_dim(state_dict, MODEL_SPECS["dinov2lin"]["feature_dim"])
267
+ if feature_dim == 1536:
268
+ candidates = ["facebook/dinov2-giant", "facebook/dinov2-large"]
269
+ elif feature_dim == 1024:
270
+ candidates = ["facebook/dinov2-large", "facebook/dinov2-base"]
271
+ elif feature_dim == 768:
272
+ candidates = ["facebook/dinov2-base", "facebook/dinov2-small"]
273
+ else:
274
+ candidates = ["facebook/dinov2-large"]
275
+
276
+ last_error = None
277
+ backbone = None
278
+ for candidate in candidates:
279
+ try:
280
+ backbone = _load_backbone(candidate)
281
+ break
282
+ except Exception as exc:
283
+ last_error = exc
284
+ if backbone is None:
285
+ raise RuntimeError(f"Failed to load DINOv2 backbone: {last_error}")
286
+
287
+ model = _ClsTokenLinearModel(backbone, feature_dim)
288
+ model = _finalize_model(model, state_dict, device=device)
289
+ return model, _build_dino_transform()
290
+
291
+
292
+ def load_dinov3lin(checkpoint_path: str | Path | None = None, device=None):
293
+ checkpoint_path = checkpoint_path or default_checkpoint_path("dinov3lin")
294
+ state_dict = _load_checkpoint(checkpoint_path)
295
+ feature_dim = _infer_feature_dim(state_dict, MODEL_SPECS["dinov3lin"]["feature_dim"])
296
+ backbone = _load_backbone(MODEL_SPECS["dinov3lin"]["hf_model"])
297
+ model = _ClsTokenLinearModel(backbone, feature_dim)
298
+ model = _finalize_model(model, state_dict, device=device)
299
+ return model, _build_dino_transform()
300
+
301
+
302
+ def load_pelin(checkpoint_path: str | Path | None = None, device=None):
303
+ checkpoint_path = checkpoint_path or default_checkpoint_path("pelin")
304
+ if str(ROOT) not in sys.path:
305
+ sys.path.insert(0, str(ROOT))
306
+
307
+ import core.vision_encoder.pe as pe
308
+ import core.vision_encoder.transforms as pe_transforms
309
+
310
+ state_dict = _load_checkpoint(checkpoint_path)
311
+ feature_dim = _infer_feature_dim(state_dict, MODEL_SPECS["pelin"]["feature_dim"])
312
+ clip_model = pe.CLIP.from_config("PE-Core-L14-336", pretrained=False)
313
+ model = _PELinearModel(clip_model.visual, feature_dim)
314
+ model = _finalize_model(model, state_dict, device=device)
315
+ return model, pe_transforms.get_image_transform(MODEL_SPECS["pelin"]["image_size"])
316
+
317
+
318
+ LOADERS: dict[str, Callable] = {
319
+ "metacliplin": load_metacliplin,
320
+ "metaclip2lin": load_metaclip2lin,
321
+ "sigliplin": load_sigliplin,
322
+ "siglip2lin": load_siglip2lin,
323
+ "pelin": load_pelin,
324
+ "dinov2lin": load_dinov2lin,
325
+ "dinov3lin": load_dinov3lin,
326
+ }
327
+
328
+
329
+ def load_model(model_name: str, checkpoint_path: str | Path | None = None, device=None):
330
+ model_name = canonical_model_name(model_name)
331
+ return LOADERS[model_name](checkpoint_path=checkpoint_path, device=device)
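Two of the checkpoint helpers above are pure dictionary logic: the `module.` stripping in `_load_checkpoint` undoes the key prefix that `nn.DataParallel` adds when saving, and `_infer_feature_dim` reads the backbone width off the linear head's weight shape `(num_classes, feature_dim)`. A standalone sketch of the same logic, using plain nested lists in place of torch tensors so it runs without the release code:

```python
# Sketch of two helpers from models.py, torch-free for illustration.

def normalize_keys(state_dict):
    """Drop a leading 'module.' from every key (nn.DataParallel checkpoints)."""
    return {
        (key[7:] if key.startswith("module.") else key): value
        for key, value in state_dict.items()
    }

def infer_feature_dim(state_dict, default_dim):
    """head.weight has shape (num_classes, feature_dim); use its column count."""
    head_weight = state_dict.get("head.weight")
    if head_weight is not None:
        return len(head_weight[0])
    return default_dim

checkpoint = {
    "module.head.weight": [[0.0] * 1024, [0.0] * 1024],  # stands in for a (2, 1024) tensor
    "module.head.bias": [0.0, 0.0],
}
normalized = normalize_keys(checkpoint)
print(sorted(normalized))                   # ['head.bias', 'head.weight']
print(infer_feature_dim(normalized, 768))   # 1024
```

The shape-based inference is what lets `load_dinov2lin` pick the right backbone size (giant/large/base) from the checkpoint alone.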
test_vfm_baselines.py ADDED
@@ -0,0 +1,153 @@
+ #!/usr/bin/env python3
+ """Unified evaluation script for the 7 VFM baselines."""
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score
+ from torch.utils.data import DataLoader, Dataset
+
+ from models import LOADERS, MODEL_SPECS, canonical_model_name, default_checkpoint_path, load_model
+
+ IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".JPG", ".JPEG", ".PNG")
+
+
+ class BinaryFolderDataset(Dataset):
+     def __init__(self, real_dir: str, fake_dir: str, transform, max_samples: int | None = None):
+         self.transform = transform
+         real_paths = self._get_image_files(real_dir)
+         fake_paths = self._get_image_files(fake_dir)
+         if max_samples is not None:
+             real_paths = real_paths[:max_samples]
+             fake_paths = fake_paths[:max_samples]
+         self.image_paths = real_paths + fake_paths
+         self.labels = [0] * len(real_paths) + [1] * len(fake_paths)
+
+     @staticmethod
+     def _get_image_files(folder: str):
+         folder = Path(folder)
+         images = []
+         for extension in IMAGE_EXTENSIONS:
+             images.extend(folder.rglob(f"*{extension}"))
+         return sorted(images)
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, index):
+         image_path = self.image_paths[index]
+         image = Image.open(image_path).convert("RGB")
+         return self.transform(image), self.labels[index], str(image_path)
+
+
+ def evaluate(model, transform, real_dir: str, fake_dir: str, batch_size: int, num_workers: int, max_samples: int | None):
+     dataset = BinaryFolderDataset(real_dir, fake_dir, transform, max_samples=max_samples)
+     dataloader = DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=num_workers,
+         pin_memory=torch.cuda.is_available(),
+     )
+
+     device = next(model.parameters()).device
+     y_true = []
+     y_prob = []
+     y_pred = []
+     paths = []
+
+     with torch.no_grad():
+         for images, labels, batch_paths in dataloader:
+             images = images.to(device)
+             logits = model(images)
+             probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
+             preds = (probs > 0.5).astype(int)
+
+             y_true.extend(labels.numpy().tolist())
+             y_prob.extend(probs.tolist())
+             y_pred.extend(preds.tolist())
+             paths.extend(batch_paths)
+
+     y_true = np.asarray(y_true)
+     y_prob = np.asarray(y_prob)
+     y_pred = np.asarray(y_pred)
+
+     metrics = {
+         "accuracy": float(accuracy_score(y_true, y_pred)),
+         "real_accuracy": float(accuracy_score(y_true[y_true == 0], y_pred[y_true == 0])),
+         "fake_accuracy": float(accuracy_score(y_true[y_true == 1], y_pred[y_true == 1])),
+     }
+     if len(np.unique(y_true)) > 1:
+         metrics["auc"] = float(roc_auc_score(y_true, y_prob))
+         metrics["ap"] = float(average_precision_score(y_true, y_prob))
+
+     samples = [
+         {
+             "path": path,
+             "label": int(label),
+             "prob_fake": float(prob),
+             "pred": int(pred),
+         }
+         for path, label, prob, pred in zip(paths, y_true, y_prob, y_pred)
+     ]
+     return {"metrics": metrics, "samples": samples}
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model", default="all", help="One of: all, metacliplin, metaclip2lin, sigliplin, siglip2lin, pelin, dinov2lin, dinov3lin")
+     parser.add_argument("--real-dir", required=True)
+     parser.add_argument("--fake-dir", required=True)
+     parser.add_argument("--checkpoint", default=None, help="Optional explicit checkpoint path for single-model evaluation")
+     parser.add_argument("--batch-size", type=int, default=8)
+     parser.add_argument("--num-workers", type=int, default=4)
+     parser.add_argument("--max-samples", type=int, default=None)
+     parser.add_argument("--device", default=None)
+     parser.add_argument("--save-json", default=None)
+     args = parser.parse_args()
+
+     model_names = list(LOADERS.keys()) if args.model == "all" else [canonical_model_name(args.model)]
+     results = {}
+
+     for model_name in model_names:
+         checkpoint = args.checkpoint if args.model != "all" and args.checkpoint else default_checkpoint_path(model_name)
+         checkpoint = Path(checkpoint)
+         try:
+             checkpoint_for_output = str(checkpoint.relative_to(Path(__file__).resolve().parent))
+         except ValueError:
+             checkpoint_for_output = str(checkpoint)
+         model, transform = load_model(model_name, checkpoint_path=checkpoint, device=args.device)
+         result = evaluate(
+             model=model,
+             transform=transform,
+             real_dir=args.real_dir,
+             fake_dir=args.fake_dir,
+             batch_size=args.batch_size,
+             num_workers=args.num_workers,
+             max_samples=args.max_samples,
+         )
+         results[model_name] = {
+             "paper_name": MODEL_SPECS[model_name]["paper_name"],
+             "checkpoint": checkpoint_for_output,
+             **result,
+         }
+
+         del model
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     output = json.dumps(results, indent=2, ensure_ascii=False)
+     print(output)
+
+     if args.save_json:
+         Path(args.save_json).write_text(output + "\n", encoding="utf-8")
+
+
+ if __name__ == "__main__":
+     main()
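Besides AUC and AP, `evaluate()` reports overall accuracy plus accuracy restricted to the real (label 0) and fake (label 1) subsets, which exposes detectors that collapse to one class. The same split can be sketched without numpy or sklearn:

```python
# Sketch of the per-class accuracy split used in evaluate(),
# on toy labels/predictions rather than model output.

def split_accuracies(y_true, y_pred):
    def acc(pairs):
        pairs = list(pairs)
        return sum(t == p for t, p in pairs) / len(pairs)

    all_pairs = list(zip(y_true, y_pred))
    return {
        "accuracy": acc(all_pairs),
        "real_accuracy": acc((t, p) for t, p in all_pairs if t == 0),  # label 0 subset
        "fake_accuracy": acc((t, p) for t, p in all_pairs if t == 1),  # label 1 subset
    }

y_true = [0, 0, 0, 1, 1, 1]
y_pred = [0, 0, 1, 1, 1, 0]
print(split_accuracies(y_true, y_pred))
# one real and one fake sample are misclassified, so all three accuracies are 2/3
```

A detector that predicts "fake" for everything would score 1.0 on `fake_accuracy` and 0.0 on `real_accuracy`, which is why the script reports both alongside the overall number.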
weights/dinov2lin0.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3c8604c137ad296d9f6bbd239d03e792cca36b2503eb03cebc5ccb5abf740ebe
+ size 4546228799
weights/dinov3lin0.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:58e35c23fc4e6a279dadedac8191b4a409a760fbdc43837af9f8541a6f7b2fb9
+ size 26864441175
weights/metaclip2lin0.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e609f38a74ad280abe48dd0dc0111ef113f5f1cd8c4a3337a8346a22afbc5258
+ size 3685870062
weights/metacliplin0.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2deed4e9d96cb5f27df0579c57028e9b855162f3daf7326ec402b580e244194c
+ size 1261744353
weights/pelin0.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d871b813db9d1cf335ba0cde1701c72baf7bc80369238d881c55a676a59b24ff
+ size 1268731407
weights/siglip2lin0.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c6971aab31d8ff4c061f4ae330b49f38c3f45a96bfb99e756f65af72e8f3f3b7
+ size 2327586086
weights/sigliplin0.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eaf932895087718f284d880fccbc565cfafbed7b2a6c12b67a346e6c878c8ab3
+ size 632730704
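The `weights/*.pth` entries above are Git LFS pointer files, not the checkpoints themselves: three `key value` lines giving the spec version, the SHA-256 object id, and the byte size of the real file. A small, hypothetical helper (not part of the release) that parses one pointer into its fields:

```python
# Parse a Git LFS pointer file (version / oid / size lines) into a dict.
# parse_lfs_pointer is an illustrative helper, not code from this repo.

def parse_lfs_pointer(text):
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:eaf932895087718f284d880fccbc565cfafbed7b2a6c12b67a346e6c878c8ab3
size 632730704"""

info = parse_lfs_pointer(pointer)
print(info["size"])  # 632730704  (bytes of the actual sigliplin0.pth checkpoint)
```

Cloning without `git lfs pull` leaves these tiny pointers on disk, so checking that each `.pth` matches its `size` field is a quick way to confirm the checkpoints actually downloaded.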