BiliSakura committed on
Commit d63a99b · verified · 1 Parent(s): 581cce9

Update all files for BitDance-ImageNet-diffusers

Files changed (1)
  1. BitDance_B_16x/transformer/src/qae.py +382 -0
BitDance_B_16x/transformer/src/qae.py ADDED
@@ -0,0 +1,382 @@
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+
+ from .gfq import GFQ
+
+ def swish(x):
+     # Swish activation: x * sigmoid(x), equivalent to SiLU
+     return x * torch.sigmoid(x)
+
+ class ResBlock(nn.Module):
+     def __init__(self,
+                  in_filters,
+                  out_filters,
+                  use_conv_shortcut=False,
+                  use_agn=False,
+                  ) -> None:
+         super().__init__()
+
+         self.in_filters = in_filters
+         self.out_filters = out_filters
+         self.use_conv_shortcut = use_conv_shortcut
+         self.use_agn = use_agn
+
+         if not use_agn:  ## AGN is an adaptive GroupNorm; skip the plain norm1 when an AGN layer runs before this block
+             self.norm1 = nn.GroupNorm(32, in_filters, eps=1e-6)
+         self.norm2 = nn.GroupNorm(32, out_filters, eps=1e-6)
+
+         self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
+         self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
+
+         if in_filters != out_filters:
+             if self.use_conv_shortcut:
+                 self.conv_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
+             else:
+                 self.nin_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding=0, bias=False)
+
+     def forward(self, x, **kwargs):
+         residual = x
+
+         if not self.use_agn:
+             x = self.norm1(x)
+         x = swish(x)
+         x = self.conv1(x)
+         x = self.norm2(x)
+         x = swish(x)
+         x = self.conv2(x)
+         if self.in_filters != self.out_filters:
+             if self.use_conv_shortcut:
+                 residual = self.conv_shortcut(residual)
+             else:
+                 residual = self.nin_shortcut(residual)
+
+         return x + residual
+
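+ # Note (editorial): ResBlock uses the pre-activation layout
+ # (GroupNorm -> swish -> conv, twice), with a 1x1 projection (or an optional
+ # 3x3 conv) on the skip path whenever the channel count changes.
+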
+ class Encoder(nn.Module):
+     def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4),
+                  resolution=None, double_z=False,
+                  ):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.z_channels = z_channels
+         self.resolution = resolution
+
+         self.num_res_blocks = num_res_blocks
+         self.num_blocks = len(ch_mult)
+
+         self.conv_in = nn.Conv2d(in_channels,
+                                  ch,
+                                  kernel_size=(3, 3),
+                                  padding=1,
+                                  bias=False
+                                  )
+
+         ## construct the model
+         self.down = nn.ModuleList()
+
+         in_ch_mult = (1,) + tuple(ch_mult)
+         for i_level in range(self.num_blocks):
+             block = nn.ModuleList()
+             block_in = ch * in_ch_mult[i_level]  # [1, 1, 2, 2, 4]
+             block_out = ch * ch_mult[i_level]  # [1, 2, 2, 4]
+             for _ in range(self.num_res_blocks):
+                 block.append(ResBlock(block_in, block_out))
+                 block_in = block_out
+
+             down = nn.Module()
+             down.block = block
+             if i_level < self.num_blocks - 1:
+                 down.downsample = nn.Conv2d(block_out, block_out, kernel_size=(3, 3), stride=(2, 2), padding=1)
+
+             self.down.append(down)
+
+         ### mid
+         self.mid_block = nn.ModuleList()
+         for res_idx in range(self.num_res_blocks):
+             self.mid_block.append(ResBlock(block_in, block_in))
+
+         ### end
+         self.norm_out = nn.GroupNorm(32, block_out, eps=1e-6)
+         self.conv_out = nn.Conv2d(block_out, z_channels, kernel_size=(1, 1))
+
+     def forward(self, x):
+
+         ## down
+         x = self.conv_in(x)
+         for i_level in range(self.num_blocks):
+             for i_block in range(self.num_res_blocks):
+                 x = self.down[i_level].block[i_block](x)
+
+             if i_level < self.num_blocks - 1:
+                 x = self.down[i_level].downsample(x)
+
+         ## mid
+         for res in range(self.num_res_blocks):
+             x = self.mid_block[res](x)
+
+         x = self.norm_out(x)
+         x = swish(x)
+         x = self.conv_out(x)
+
+         return x
+
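+ # Shape sketch (editorial, illustrative): with the default ch_mult=(1, 2, 2, 4)
+ # there are 4 levels but only 3 strided downsamples (i_level < num_blocks - 1),
+ # so a 3 x 256 x 256 input yields a z_channels x 32 x 32 latent.
+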
+ class Decoder(nn.Module):
+     def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4),
+                  resolution=None, double_z=False,) -> None:
+         super().__init__()
+
+         self.ch = ch
+         self.num_blocks = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+
+         block_in = ch * ch_mult[self.num_blocks - 1]
+
+         self.conv_in = nn.Conv2d(
+             z_channels, block_in, kernel_size=(3, 3), padding=1, bias=True
+         )
+
+         self.mid_block = nn.ModuleList()
+         for res_idx in range(self.num_res_blocks):
+             self.mid_block.append(ResBlock(block_in, block_in))
+
+         self.up = nn.ModuleList()
+
+         self.adaptive = nn.ModuleList()
+
+         for i_level in reversed(range(self.num_blocks)):
+             block = nn.ModuleList()
+             block_out = ch * ch_mult[i_level]
+             self.adaptive.insert(0, AdaptiveGroupNorm(z_channels, block_in))
+             for i_block in range(self.num_res_blocks):
+                 block.append(ResBlock(block_in, block_out))
+                 block_in = block_out
+
+             up = nn.Module()
+             up.block = block
+             if i_level > 0:
+                 up.upsample = Upsampler(block_in)
+             self.up.insert(0, up)
+
+         self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6)
+
+         self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=(3, 3), padding=1)
+
+     def forward(self, z):
+
+         style = z.clone()  # kept for the adaptive GroupNorm layers
+
+         z = self.conv_in(z)
+
+         ## mid
+         for res in range(self.num_res_blocks):
+             z = self.mid_block[res](z)
+
+         ## upsample
+         for i_level in reversed(range(self.num_blocks)):
+             ### apply AdaGN before each level's resblocks
+             z = self.adaptive[i_level](z, style)
+             for i_block in range(self.num_res_blocks):
+                 z = self.up[i_level].block[i_block](z)
+
+             if i_level > 0:
+                 z = self.up[i_level].upsample(z)
+
+         z = self.norm_out(z)
+         z = swish(z)
+         z = self.conv_out(z)
+
+         return z
+
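+ # The decoder mirrors the encoder: (num_blocks - 1) Upsampler stages undo the
+ # encoder's /8 spatial reduction, and the raw latent z is reused as the
+ # "style" input that modulates every AdaptiveGroupNorm level.
+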
+ def depth_to_space(x: torch.Tensor, block_size: int) -> torch.Tensor:
+     """Depth-to-space, DCR mode (depth-column-row) core implementation.
+
+     Args:
+         x (torch.Tensor): input tensor. The channels-first (*CHW) layout is supported.
+         block_size (int): block side size
+     """
+     # check inputs
+     if x.dim() < 3:
+         raise ValueError(
+             "Expecting a channels-first (*CHW) tensor of at least 3 dimensions"
+         )
+     c, h, w = x.shape[-3:]
+
+     s = block_size ** 2
+     if c % s != 0:
+         raise ValueError(
+             f"Expecting a channels-first (*CHW) tensor with C divisible by {s}, but got C={c} channels"
+         )
+
+     outer_dims = x.shape[:-3]
+
+     # split two additional dimensions out of the channel dimension
+     x = x.view(-1, block_size, block_size, c // s, h, w)
+
+     # put the two new dimensions along H and W
+     x = x.permute(0, 3, 4, 1, 5, 2)
+
+     # merge the two new dimensions with H and W
+     x = x.contiguous().view(*outer_dims, c // s, h * block_size,
+                             w * block_size)
+
+     return x
+
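+ # Quick shape check (editorial, illustrative): a (2, 16, 8, 8) input with
+ # block_size=2 becomes (2, 4, 16, 16) -- channels shrink by block_size**2
+ # while H and W each grow by block_size.
+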
+ class Upsampler(nn.Module):
+     def __init__(
+         self,
+         dim,
+         dim_out=None
+     ):
+         super().__init__()
+         # default to 4x channel expansion so that depth_to_space(block_size=2)
+         # returns `dim` channels at twice the spatial resolution; an explicit
+         # dim_out must be divisible by 4
+         dim_out = dim_out if dim_out is not None else dim * 4
+         self.conv1 = nn.Conv2d(dim, dim_out, (3, 3), padding=1)
+         self.depth2space = depth_to_space
+
+     def forward(self, x):
+         """
+         input_image: [B C H W]
+         """
+         out = self.conv1(x)
+         out = self.depth2space(out, block_size=2)
+         return out
+
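+ # Taken together, conv1 expands channels 4x and depth_to_space(block_size=2)
+ # folds them back into space, so Upsampler doubles H and W while preserving
+ # the channel count -- a sub-pixel (pixel-shuffle style) upsampler.
+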
+ class AdaptiveGroupNorm(nn.Module):
+     def __init__(self, z_channel, in_filters, num_groups=32, eps=1e-6):
+         super().__init__()
+         self.gn = nn.GroupNorm(num_groups=num_groups, num_channels=in_filters, eps=eps, affine=False)
+         self.gamma = nn.Linear(z_channel, in_filters)
+         self.beta = nn.Linear(z_channel, in_filters)
+         self.eps = eps
+
+     def forward(self, x, quantizer):
+         B, C, _, _ = x.shape
+         ### calculate the per-channel std of the quantized latent for the scale
+         scale = rearrange(quantizer, "b c h w -> b c (h w)")
+         scale = scale.var(dim=-1) + self.eps  # variance over spatial positions
+         scale = scale.sqrt()
+         scale = self.gamma(scale).view(B, C, 1, 1)
+
+         ### calculate the per-channel mean of the quantized latent for the bias
+         bias = rearrange(quantizer, "b c h w -> b c (h w)")
+         bias = bias.mean(dim=-1)
+         bias = self.beta(bias).view(B, C, 1, 1)
+
+         x = self.gn(x)
+         x = scale * x + bias
+
+         return x
+
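+ # AdaptiveGroupNorm is FiLM-style conditioning: the per-channel spatial std
+ # and mean of the quantized latent are projected by the `gamma`/`beta`
+ # linears into a scale and bias applied after a non-affine GroupNorm.
+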
+ class GANDecoder(nn.Module):
+     def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4),
+                  resolution=None, double_z=False,) -> None:
+         super().__init__()
+
+         self.ch = ch
+         self.num_blocks = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+
+         block_in = ch * ch_mult[self.num_blocks - 1]
+
+         # the latent is concatenated with noise, hence z_channels * 2 input channels
+         self.conv_in = nn.Conv2d(
+             z_channels * 2, block_in, kernel_size=(3, 3), padding=1, bias=True
+         )
+
+         self.mid_block = nn.ModuleList()
+         for res_idx in range(self.num_res_blocks):
+             self.mid_block.append(ResBlock(block_in, block_in))
+
+         self.up = nn.ModuleList()
+
+         self.adaptive = nn.ModuleList()
+
+         for i_level in reversed(range(self.num_blocks)):
+             block = nn.ModuleList()
+             block_out = ch * ch_mult[i_level]
+             self.adaptive.insert(0, AdaptiveGroupNorm(z_channels, block_in))
+             for i_block in range(self.num_res_blocks):
+                 block.append(ResBlock(block_in, block_out))
+                 block_in = block_out
+
+             up = nn.Module()
+             up.block = block
+             if i_level > 0:
+                 up.upsample = Upsampler(block_in)
+             self.up.insert(0, up)
+
+         self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6)
+
+         self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=(3, 3), padding=1)
+
+     def forward(self, z):
+
+         style = z.clone()  # kept for the adaptive GroupNorm layers
+
+         noise = torch.randn_like(z)  # randn_like already matches z's device and dtype
+         z = torch.cat([z, noise], dim=1)  # concatenate noise with the latent
+         z = self.conv_in(z)
+
+         ## mid
+         for res in range(self.num_res_blocks):
+             z = self.mid_block[res](z)
+
+         ## upsample
+         for i_level in reversed(range(self.num_blocks)):
+             ### apply AdaGN before each level's resblocks
+             z = self.adaptive[i_level](z, style)
+             for i_block in range(self.num_res_blocks):
+                 z = self.up[i_level].block[i_block](z)
+
+             if i_level > 0:
+                 z = self.up[i_level].upsample(z)
+
+         z = self.norm_out(z)
+         z = swish(z)
+         z = self.conv_out(z)
+
+         return z
+
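+ # GANDecoder differs from Decoder only at the input: Gaussian noise is
+ # concatenated with the latent along channels (hence conv_in takes
+ # z_channels * 2), injecting stochasticity for adversarial training.
+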
+ class VQModel(nn.Module):
+     def __init__(self,
+                  ddconfig,
+                  num_codebooks=1,
+                  sample_minimization_weight=1,
+                  batch_maximization_weight=1,
+                  gan_decoder=False,
+                  ):
+         super().__init__()
+         self.encoder = Encoder(**ddconfig)
+         self.decoder = GANDecoder(**ddconfig) if gan_decoder else Decoder(**ddconfig)
+         self.quantize = GFQ(dim=ddconfig.get("z_channels", 32),
+                             num_codebooks=num_codebooks,
+                             sample_minimization_weight=sample_minimization_weight,
+                             batch_maximization_weight=batch_maximization_weight,
+                             )
+
+     def encode(self, x):
+         h = self.encoder(x)
+         (quant, emb_loss, info), loss_breakdown = self.quantize(h, return_loss_breakdown=True)
+         return quant, emb_loss, info, loss_breakdown
+
+     def decode(self, quant):
+         dec = self.decoder(quant)
+         return dec
+
+     def forward(self, input):
+         quant, _, _, loss_break = self.encode(input)
+         dec = self.decode(quant)
+         return dec, loss_break
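+
+ # Minimal usage sketch (editorial, illustrative only; the ddconfig values
+ # below follow the constructor signatures above but are assumptions, not the
+ # shipped configuration, and GFQ must be importable from .gfq):
+ #
+ #     ddconfig = dict(ch=128, out_ch=3, in_channels=3, num_res_blocks=2,
+ #                     z_channels=32, ch_mult=(1, 2, 2, 4))
+ #     model = VQModel(ddconfig)
+ #     x = torch.randn(1, 3, 256, 256)
+ #     recon, loss_break = model(x)  # recon: (1, 3, 256, 256)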