BiliSakura committed
Commit c0afd19 · verified · 1 Parent(s): dbbed72

Update all files for BitDance-ImageNet-diffusers

Files changed (1)
  1. BitDance_H_1x/transformer/gfq.py +312 -0
BitDance_H_1x/transformer/gfq.py ADDED
@@ -0,0 +1,312 @@
+ """
+ Lookup Free Quantization
+ Proposed in https://arxiv.org/abs/2310.05737
+
+ In the simplest setup, each dimension is quantized into {-1, 1}.
+ An entropy penalty is used to encourage utilization.
+
+ Refer to
+ https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py
+ https://github.com/theAdamColton/ijepa-enhanced/blob/7edef5f7288ae8f537f0db8a10044a2a487f70c9/ijepa_enhanced/lfq.py
+ """
+
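+ # Worked example of the scheme above (illustrative sketch): a 3-dimensional
+ # feature such as [0.3, -1.2, 0.7] is sign-quantized to [+1, -1, +1]; reading
+ # the positive dimensions as set bits with weights [1, 2, 4] gives codebook
+ # index 1 + 0 + 4 = 5, one of 2 ** 3 = 8 implicit codes.
+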
+ from math import log2, ceil
+ from collections import namedtuple
+
+ import torch
+ from torch import nn, einsum
+ import torch.nn.functional as F
+ from torch.nn import Module
+
+ from einops import rearrange, reduce, pack, unpack
+
+ # constants
+ LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'codebook_entropy', 'commitment', 'avg_probs'])
+
+ # helper functions
+ def exists(v):
+     return v is not None
+
+ def default(*args):
+     for arg in args:
+         if exists(arg):
+             return arg() if callable(arg) else arg
+     return None
+
+ def pack_one(t, pattern):
+     return pack([t], pattern)
+
+ def unpack_one(t, ps, pattern):
+     return unpack(t, ps, pattern)[0]
+
+ # entropy
+ def entropy(prob):
+     return (-prob * torch.log(prob + 1e-5)).sum(dim=-1)
+
+ # masked reduction helpers
+ def mult_along_first_dims(x, y):
+     """
+     returns x * y elementwise along the leading dimensions of y
+     """
+     ndim_to_expand = x.ndim - y.ndim
+     for _ in range(ndim_to_expand):
+         y = y.unsqueeze(-1)
+     return x * y
+
+ def masked_mean(x, m):
+     """
+     takes the mean of the elements of x that are not masked
+     the mean is taken along the shared leading dims of m
+     equivalent to: x[m].mean(0)
+
+     The benefit of using masked_mean rather than using
+     tensor indexing is that masked_mean is much faster
+     for torch-compile on batches.
+
+     The drawback is larger floating point errors
+     """
+     x = mult_along_first_dims(x, m)
+     x = x / m.sum()
+     return x.sum(tuple(range(m.ndim)))
+
+
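+ # Shape example for masked_mean (sketch): for x of shape (B, N, D) and a boolean
+ # mask m of shape (B, N), the result averages x over the masked (B, N) positions
+ # and has shape (D,), matching x[m].mean(0) up to floating point error.
+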
+ def entropy_loss(
+     logits,
+     mask=None,
+     temperature=0.01,
+     sample_minimization_weight=1.0,
+     batch_maximization_weight=1.0,
+     eps=1e-5,
+ ):
+     """
+     Entropy loss of unnormalized logits
+
+     logits: Affinities are over the last dimension
+
+     https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279
+     LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024)
+     """
+     probs = F.softmax(logits / temperature, -1)
+     log_probs = F.log_softmax(logits / temperature + eps, -1)
+
+     if mask is not None:
+         # avg_probs = probs[mask].mean(tuple(range(probs.ndim - 1)))
+         # avg_probs = einx.mean("... D -> D", probs[mask])
+
+         avg_probs = masked_mean(probs, mask)
+         # avg_probs = einx.mean("... D -> D", avg_probs)
+     else:
+         avg_probs = reduce(probs, "... D -> D", "mean")
+
+     avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps))
+
+     sample_entropy = -torch.sum(probs * log_probs, -1)
+     if mask is not None:
+         # sample_entropy = sample_entropy[mask].mean()
+         sample_entropy = masked_mean(sample_entropy, mask).mean()
+     else:
+         sample_entropy = torch.mean(sample_entropy)
+
+     loss = (sample_minimization_weight * sample_entropy) - (
+         batch_maximization_weight * avg_entropy
+     )
+
+     return sample_entropy, avg_entropy, loss
+
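+ # entropy_loss in words (sketch): with p = softmax(logits / temperature), the
+ # returned loss is
+ #     sample_minimization_weight * mean(H(p_per_token)) - batch_maximization_weight * H(mean(p))
+ # so each token is pushed toward a confident (low-entropy) code assignment while
+ # the batch-averaged code distribution is pushed toward uniform (high-entropy)
+ # codebook usage.
+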
+
+ class GFQ(Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         num_codebooks = 1,
+         sample_minimization_weight=1.0,
+         batch_maximization_weight=1.0,
+     ):
+         super().__init__()
+         self.token_factorization = num_codebooks > 1
+         self.codebook_dim = dim // num_codebooks
+         self.codebook_size = 2 ** self.codebook_dim
+         self.dim = dim
+         self.num_codebooks = num_codebooks
+         self.vocab_size = num_codebooks * self.codebook_size
+
+         # for entropy loss
+         self.sample_minimization_weight = sample_minimization_weight
+         self.batch_maximization_weight = batch_maximization_weight
+
+         # per-head bit weights (1, 2, 4, ...) used to turn sign patterns into integer indices
+         self.factorized_bits = [self.codebook_dim] * num_codebooks
+         for i, factorized_bit in enumerate(self.factorized_bits):
+             self.register_buffer(f"mask_{i}", 2 ** torch.arange(factorized_bit), persistent=False)
+         # full-width bit weights for the non-factorized (single codebook) path
+         self.register_buffer("mask", 2 ** torch.arange(self.codebook_dim), persistent=False)
+
+         # codes
+         all_codes = torch.arange(self.codebook_size)
+         bits = self.indices_to_bits(all_codes)
+         codebook = bits * 2.0 - 1.0
+         self.register_buffer('codebook', codebook, persistent = False)
+         self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
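+     # Example (sketch): with dim = 4 and num_codebooks = 1, codebook_dim = 4 and
+     # codebook_size = 2 ** 4 = 16; the codebook rows enumerate every sign pattern,
+     # e.g. codebook[5] = [+1., -1., +1., -1.].
+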
+     @property
+     def dtype(self):
+         return self.codebook.dtype
+
+     def indices_to_bits(self, x):
+         """
+         x: long tensor of indices
+
+         returns big endian bits
+         """
+         mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long)
+         x = (x.unsqueeze(-1) & mask) != 0 # x is now big endian bits, the last dimension being the bits
+         return x
+
+     def get_codebook_entry(self, x, bhwc, index_order): #0610
+         mask = getattr(self, f"mask_{index_order}") if self.token_factorization else self.mask
+         mask = mask.to(device=x.device, dtype=torch.long)
+
+         x = (x.unsqueeze(-1) & mask) != 0
+         x = x * 2.0 - 1.0 # back to float codes
+         b, h, w, c = bhwc
+         x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c)
+         x = rearrange(x, "b h w c -> b c h w") # back to channel-first layout
+         return x
+
+     def bits_to_indices(self, bits):
+         """
+         bits: bool tensor of big endian bits, where the last dimension is the bit dimension
+
+         returns indices, which are long integers from 0 to self.codebook_size - 1
+         """
+         assert bits.shape[-1] == self.codebook_dim
+         indices = 2 ** torch.arange(
+             0,
+             self.codebook_dim,
+             1,
+             dtype=torch.long,
+             device=bits.device,
+         )
+         return (bits * indices).sum(-1)
+
+     def decode(self, x):
+         """
+         x: ... NH
+         where NH is number of codebook heads
+         a long tensor of codebook indices, containing values from
+         0 to self.codebook_size - 1
+         """
+         x = self.indices_to_bits(x)
+         x = x.to(self.dtype) # to some sort of float
+         x = x * 2 - 1 # -1 or 1
+         x = rearrange(x, "... NC Z -> ... (NC Z)")
+         return x
+
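+     # Example for decode (sketch): with codebook_dim = 3 and a single head, index
+     # 5 corresponds to bits [1, 0, 1] under the weights [1, 2, 4], so decode
+     # returns the code [+1., -1., +1.]; with several heads the per-head codes are
+     # concatenated along the last dimension.
+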
+     def forward(
+         self,
+         x,
+         inv_temperature = 100.,
+         return_loss_breakdown = False,
+         mask = None,
+         return_loss = True,
+     ):
+         """
+         einstein notation
+         b - batch
+         n - sequence (or flattened spatial dimensions)
+         d - feature dimension, which is also log2(codebook size)
+         c - number of codebooks
+         """
+         x = rearrange(x, 'b d ... -> b ... d')
+         x, ps = pack_one(x, 'b * d')
+         x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks) # split out number of codebooks
+
+         codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype)
+         quantized = torch.where(x > 0, codebook_value, -codebook_value) # higher than 0 filled
+
+         # calculate indices
+         if self.token_factorization:
+             quantized = rearrange(quantized, 'b n c d -> b n 1 (c d)')
+             indices_list = []
+             begin = 0
+             end = 0
+             for i, factorized_bit in enumerate(self.factorized_bits):
+                 end += factorized_bit
+                 mask_name = f"mask_{i}"
+                 # use a separate name so the `mask` argument (used for the commit loss
+                 # below) is not overwritten by the bit-weight buffer
+                 bit_mask = getattr(self, mask_name)
+                 indices = reduce((quantized[..., begin:end] > 0).int() * bit_mask.int(), "b n c d -> b n c", "sum")
+                 indices_list.append(indices)
+                 begin += factorized_bit
+             quantized = rearrange(quantized, 'b n 1 (c d) -> b n c d', c = self.num_codebooks)
+         else:
+             indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
+
+         # entropy aux loss
+         if self.training and return_loss:
+             logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook)
+             # the same as euclidean distance up to a constant
+             per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss(
+                 logits = logits,
+                 sample_minimization_weight = self.sample_minimization_weight,
+                 batch_maximization_weight = self.batch_maximization_weight
+             )
+
+             avg_probs = self.zero
+         else:
+             per_sample_entropy = codebook_entropy = self.zero
+             entropy_aux_loss = self.zero
+             avg_probs = self.zero
+
+         # commit loss
+         if self.training:
+             commit_loss = F.mse_loss(x, quantized.detach(), reduction = 'none')
+
+             if exists(mask):
+                 commit_loss = commit_loss[mask]
+
+             commit_loss = commit_loss.mean()
+         else:
+             commit_loss = self.zero
+
+
+         # use straight-through gradients (optionally with custom activation fn) if training
+         if self.training:
+             quantized = x + (quantized - x).detach() # pass gradients straight through to x
+
+         # merge back codebook dim
+         quantized = rearrange(quantized, 'b n c d -> b n (c d)')
+
+         # reconstitute image or video dimensions
+         quantized = unpack_one(quantized, ps, 'b * d')
+         quantized = rearrange(quantized, 'b ... d -> b d ...')
+
+         if self.token_factorization:
+             indices_ = []
+             for i, indices in enumerate(indices_list):
+                 indices = unpack_one(indices, ps, "b * c")
+                 indices = indices.flatten()
+                 indices_.append(indices)
+             indices = indices_
+         else:
+             indices = unpack_one(indices, ps, 'b * c')
+             indices = indices.flatten()
+
+         ret = (quantized, entropy_aux_loss, indices)
+
+         if not return_loss_breakdown:
+             return ret
+
+         return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss, avg_probs)
+
+
+ if __name__ == "__main__":
+     quantizer = GFQ(
+         dim = 18, # input feature dimension; the codebook size is 2 ** dim = 2 ** 18
+         sample_minimization_weight = 1.0, # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
+         batch_maximization_weight = 1.0,
+     )
+
+     image_feats = torch.randn(2, 18, 16, 16) # (batch, dim, height, width); dim must match the quantizer's dim
+
+     quantized, entropy_aux_loss, indices = quantizer(image_feats, inv_temperature=100.) # you may want to experiment with temperature
+
+     assert image_feats.shape == quantized.shape
+
+     # round-trip: decode the flattened indices back to codes and compare
+     codes = quantizer.decode(indices.view(2, 16 * 16, 1))
+     assert torch.allclose(codes, rearrange(quantized, 'b d h w -> b (h w) d'))
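+
+     # Factorized usage (illustrative sketch, assuming num_codebooks = 2): two 9-bit
+     # codebook heads, so each spatial position yields two sub-indices in [0, 2 ** 9).
+     factorized = GFQ(dim = 18, num_codebooks = 2)
+     quantized_f, aux_loss_f, (indices_a, indices_b) = factorized(image_feats)
+     assert quantized_f.shape == image_feats.shape
+     assert int(indices_a.max()) < 2 ** 9 and int(indices_b.max()) < 2 ** 9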