Update all files for BitDance-ImageNet-diffusers
BitDance_B_16x/transformer/gfq.py
ADDED
@@ -0,0 +1,312 @@
"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737

In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.

Refer to
https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py
https://github.com/theAdamColton/ijepa-enhanced/blob/7edef5f7288ae8f537f0db8a10044a2a487f70c9/ijepa_enhanced/lfq.py
"""

from math import log2, ceil
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module

from einops import rearrange, reduce, pack, unpack

# constants
LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'codebook_entropy', 'commitment', 'avg_probs'])

# helper functions
def exists(v):
    return v is not None

def default(*args):
    for arg in args:
        if exists(arg):
            return arg() if callable(arg) else arg
    return None

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# entropy
def entropy(prob):
    return (-prob * torch.log(prob + 1e-5)).sum(dim=-1)

# class
def mult_along_first_dims(x, y):
    """
    returns x * y elementwise along the leading dimensions of y
    """
    ndim_to_expand = x.ndim - y.ndim
    for _ in range(ndim_to_expand):
        y = y.unsqueeze(-1)
    return x * y

def masked_mean(x, m):
    """
    takes the mean of the elements of x that are not masked
    the mean is taken along the shared leading dims of m
    equivalent to: x[m].mean(tuple(range(m.ndim)))

    The benefit of using masked_mean rather than using
    tensor indexing is that masked_mean is much faster
    for torch-compile on batches.

    The drawback is larger floating point errors
    """
    x = mult_along_first_dims(x, m)
    x = x / m.sum()
    return x.sum(tuple(range(m.ndim)))


def entropy_loss(
    logits,
    mask=None,
    temperature=0.01,
    sample_minimization_weight=1.0,
    batch_maximization_weight=1.0,
    eps=1e-5,
):
    """
    Entropy loss of unnormalized logits

    logits: Affinities are over the last dimension

    https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279
    LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024)
    """
    probs = F.softmax(logits / temperature, -1)
    log_probs = F.log_softmax(logits / temperature + eps, -1)

    if mask is not None:
        # avg_probs = probs[mask].mean(tuple(range(probs.ndim - 1)))
        # avg_probs = einx.mean("... D -> D", probs[mask])

        avg_probs = masked_mean(probs, mask)
        # avg_probs = einx.mean("... D -> D", avg_probs)
    else:
        avg_probs = reduce(probs, "... D -> D", "mean")

    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps))

    sample_entropy = -torch.sum(probs * log_probs, -1)
    if mask is not None:
        # sample_entropy = sample_entropy[mask].mean()
        sample_entropy = masked_mean(sample_entropy, mask).mean()
    else:
        sample_entropy = torch.mean(sample_entropy)

    loss = (sample_minimization_weight * sample_entropy) - (
        batch_maximization_weight * avg_entropy
    )

    return sample_entropy, avg_entropy, loss
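
# Note on the two entropy terms: minimizing the per-sample entropy pushes each token towards a
# confident, near one-hot assignment to a single code, while maximizing the batch-averaged entropy
# spreads assignments across the codebook so that all codes get used. The returned `loss` is the
# weighted difference of the two, following the references cited in the docstring above.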


class GFQ(Module):
    def __init__(
        self,
        *,
        dim,
        num_codebooks = 1,
        sample_minimization_weight=1.0,
        batch_maximization_weight=1.0,
    ):
        super().__init__()
        self.token_factorization = num_codebooks > 1
        self.codebook_dim = dim // num_codebooks
        self.codebook_size = 2 ** self.codebook_dim
        self.dim = dim
        self.num_codebooks = num_codebooks
        self.vocab_size = num_codebooks * self.codebook_size

        # for entropy loss
        self.sample_minimization_weight = sample_minimization_weight
        self.batch_maximization_weight = batch_maximization_weight
        self.factorized_bits = [self.codebook_dim] * num_codebooks
        for i, factorized_bit in enumerate(self.factorized_bits):
            self.register_buffer(f"mask_{i}", 2 ** torch.arange(factorized_bit), persistent=False)
        # bit weights for the non-factorized path, referenced as `self.mask` in forward / get_codebook_entry
        self.register_buffer("mask", 2 ** torch.arange(self.codebook_dim), persistent=False)
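        # With codebook_dim bits per codebook, mask_i = [1, 2, 4, ..., 2 ** (codebook_dim - 1)];
        # an integer token index is recovered by summing the weights of a code's positive dimensions.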

        # codes
        all_codes = torch.arange(self.codebook_size)
        bits = self.indices_to_bits(all_codes)
        codebook = bits * 2.0 - 1.0
        self.register_buffer('codebook', codebook, persistent = False)
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

    @property
    def dtype(self):
        return self.codebook.dtype

    def indices_to_bits(self, x):
        """
        x: long tensor of indices

        returns big endian bits
        """
        mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long)
        x = (x.unsqueeze(-1) & mask) != 0 # x is now big endian bits, the last dimension being the bits
        return x
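
    # e.g. with codebook_dim = 3 the bit weights are [1, 2, 4], so index 5 -> bits [True, False, True]
    # -> code [+1, -1, +1]; bits_to_indices below reverses this by summing the weights of the set bits (1 + 4 = 5)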

    def get_codebook_entry(self, x, bhwc, index_order): #0610
        mask = getattr(self, f"mask_{index_order}") if self.token_factorization else self.mask
        mask = mask.to(device=x.device, dtype=torch.long)

        x = (x.unsqueeze(-1) & mask) != 0
        x = x * 2.0 - 1.0 #back to the float
        b, h, w, c = bhwc
        x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c)
        x = rearrange(x, "b h w c -> b c h w") ## scale back
        return x
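
    # get_codebook_entry turns a (b, h*w) tensor of token indices back into a (b, c, h, w) map of
    # +/-1 codes (c = bits per codebook); `index_order` selects which sub-codebook's bit weights
    # to use when token factorization is enabled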

    def bits_to_indices(self, bits):
        """
        bits: bool tensor of big endian bits, where the last dimension is the bit dimension

        returns indices, which are long integers from 0 to self.codebook_size
        """
        assert bits.shape[-1] == self.codebook_dim
        indices = 2 ** torch.arange(
            0,
            self.codebook_dim,
            1,
            dtype=torch.long,
            device=bits.device,
        )
        return (bits * indices).sum(-1)

    def decode(self, x):
        """
        x: ... NH
            where NH is number of codebook heads
            A longtensor of codebook indices, containing values from
            0 to self.codebook_size
        """
        x = self.indices_to_bits(x)
        x = x.to(self.dtype) # to some sort of float
        x = x * 2 - 1 # -1 or 1
        x = rearrange(x, "... NC Z-> ... (NC Z)")
        return x
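
    # decode is the inverse of the index computation in forward: each index is expanded back into
    # its +/-1 code of width codebook_dim, and the codes of all heads are concatenated along the
    # last dimension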

    def forward(
        self,
        x,
        inv_temperature = 100.,
        return_loss_breakdown = False,
        mask = None,
        return_loss = True,
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim
        """
        x = rearrange(x, 'b d ... -> b ... d')
        x, ps = pack_one(x, 'b * d')
        x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks) # split out number of codebooks

        codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype)
        quantized = torch.where(x > 0, codebook_value, -codebook_value) # higher than 0 filled
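
        # lookup-free binary quantization: each channel is mapped to +1 or -1 purely by its sign,
        # so no nearest-neighbour search over a learned codebook is needed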

        # calculate indices
        if self.token_factorization:
            quantized = rearrange(quantized, 'b n c d -> b n 1 (c d)')
            indices_list = []
            begin = 0
            end = 0
            for i, factorized_bit in enumerate(self.factorized_bits):
                end += factorized_bit
                mask_name = f"mask_{i}"
                bit_mask = getattr(self, mask_name)  # local name so the `mask` argument (used for the commit loss below) is not shadowed
                indices = reduce((quantized[..., begin:end] > 0).int() * bit_mask.int(), "b n c d -> b n c", "sum")
                indices_list.append(indices)
                begin += factorized_bit
            quantized = rearrange(quantized, 'b n 1 (c d) -> b n c d', c = self.num_codebooks)
        else:
            indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
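
        # with token factorization (num_codebooks > 1) every chunk of codebook_dim channels is packed
        # into its own integer index, giving one index tensor per sub-codebook; otherwise a single
        # index per position is produced from the bit weights in self.mask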

        # entropy aux loss
        if self.training and return_loss:
            logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook)
            # the same as euclidean distance up to a constant
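            # (||x - c||^2 = ||x||^2 + ||c||^2 - 2<x, c>; ||x||^2 is the same for every code and
            # ||c||^2 = codebook_dim for every codeword since all entries are +/-1, so 2<x, c>
            # differs from the negative squared distance only by a per-token constant, which leaves
            # the softmax over codes unchanged)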
            per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss(
                logits = logits,
                sample_minimization_weight = self.sample_minimization_weight,
                batch_maximization_weight = self.batch_maximization_weight
            )

            avg_probs = self.zero
        else:
            per_sample_entropy = codebook_entropy = self.zero
            entropy_aux_loss = self.zero
            avg_probs = self.zero

        # commit loss
        if self.training:
            commit_loss = F.mse_loss(x, quantized.detach(), reduction = 'none')

            if exists(mask):
                commit_loss = commit_loss[mask]

            commit_loss = commit_loss.mean()
        else:
            commit_loss = self.zero


        # use straight-through gradients (optionally with custom activation fn) if training
        if self.training:
            quantized = x + (quantized - x).detach() #transfer to quantized
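
        # straight-through estimator: the forward value is the hard +/-1 code, while the gradient
        # w.r.t. the encoder output x passes through as identity, so the encoder still trains
        # despite the non-differentiable sign operation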

        # merge back codebook dim
        quantized = rearrange(quantized, 'b n c d -> b n (c d)')

        # reconstitute image or video dimensions
        quantized = unpack_one(quantized, ps, 'b * d')
        quantized = rearrange(quantized, 'b ... d -> b d ...')

        if self.token_factorization:
            indices_ = []
            for i, indices in enumerate(indices_list):
                indices = unpack_one(indices, ps, "b * c")
                indices = indices.flatten()
                indices_.append(indices)
            indices = indices_
        else:
            indices = unpack_one(indices, ps, 'b * c')
            indices = indices.flatten()

        ret = (quantized, entropy_aux_loss, indices)

        if not return_loss_breakdown:
            return ret

        return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss, avg_probs)


if __name__ == "__main__":
    quantizer = GFQ(
        dim = 18,                          # input feature dimension; the codebook size is 2 ** dim = 2 ** 18
        sample_minimization_weight = 1.0,  # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
        batch_maximization_weight = 1.0
    )
    quantizer.eval()  # disable the straight-through estimator so the returned codes are exactly +/-1

    image_feats = torch.randn(2, 18, 16, 16)  # (batch, dim, height, width); dim must match the quantizer dim

    quantized, entropy_aux_loss, indices = quantizer(image_feats, inv_temperature=100.)  # you may want to experiment with temperature

    assert image_feats.shape == quantized.shape

    # round trip: decode the flattened indices back to +/-1 codes and compare with the quantized output
    codes = quantizer.decode(indices.view(2, 16 * 16, 1))            # (b, n, d)
    codes = rearrange(codes, 'b (h w) d -> b d h w', h=16, w=16)
    assert (quantized == codes).all()