BiliSakura committed verified commit dd6caba (1 parent: a296060)

Update all files for BitDance-Tokenizer-diffusers
bitdance_diffusers/modeling_autoencoder.py ADDED
@@ -0,0 +1,322 @@
from __future__ import annotations

from typing import Any, Dict, Optional, Sequence

import torch
from einops import rearrange
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


def swish(x: torch.Tensor) -> torch.Tensor:
    # SiLU/Swish activation: x * sigmoid(x).
    return x * torch.sigmoid(x)


class ResBlock(nn.Module):
    # Pre-activation residual block: (GroupNorm -> swish -> conv) twice, with
    # a learned shortcut when the channel count changes. With use_agn=True the
    # first normalization is expected to be applied externally
    # (AdaptiveGroupNorm), so norm1 is skipped here.
    def __init__(
        self,
        in_filters: int,
        out_filters: int,
        use_conv_shortcut: bool = False,
        use_agn: bool = False,
    ) -> None:
        super().__init__()
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.use_conv_shortcut = use_conv_shortcut
        self.use_agn = use_agn

        if not use_agn:
            self.norm1 = nn.GroupNorm(32, in_filters, eps=1e-6)
        self.norm2 = nn.GroupNorm(32, out_filters, eps=1e-6)

        self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, bias=False)

        if in_filters != out_filters:
            if use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=3, padding=1, bias=False)
            else:
                self.nin_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=1, padding=0, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        if not self.use_agn:
            x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)

        if self.in_filters != self.out_filters:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return x + residual


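# Usage sketch (not part of the original module): a minimal shape check of the
# channel-changing shortcut. Sizes are illustrative assumptions, not the
# released checkpoint's configuration.
def _resblock_shape_check() -> None:
    block = ResBlock(in_filters=64, out_filters=128)
    x = torch.randn(1, 64, 16, 16)
    assert block(x).shape == (1, 128, 16, 16)

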
class Encoder(nn.Module):
    # Downsampling CNN encoder: each resolution level runs num_res_blocks
    # ResBlocks followed by a strided-conv downsample (all levels but the
    # last), for an overall spatial reduction of 2 ** (len(ch_mult) - 1).
    def __init__(
        self,
        *,
        ch: int,
        out_ch: int,
        in_channels: int,
        num_res_blocks: int,
        z_channels: int,
        ch_mult: Sequence[int] = (1, 2, 2, 4),
        resolution: Optional[int] = None,
        double_z: bool = False,
    ) -> None:
        super().__init__()
        del out_ch, double_z  # accepted for config compatibility, unused here
        self.in_channels = in_channels
        self.z_channels = z_channels
        self.resolution = resolution
        self.num_res_blocks = num_res_blocks
        self.num_blocks = len(ch_mult)

        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, padding=1, bias=False)
        self.down = nn.ModuleList()

        in_ch_mult = (1,) + tuple(ch_mult)
        block_out = ch * ch_mult[0]
        for i_level in range(self.num_blocks):
            block = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResBlock(block_in, block_out))
                block_in = block_out

            down = nn.Module()
            down.block = block
            if i_level < self.num_blocks - 1:
                down.downsample = nn.Conv2d(block_out, block_out, kernel_size=3, stride=2, padding=1)
            self.down.append(down)

        self.mid_block = nn.ModuleList([ResBlock(block_out, block_out) for _ in range(self.num_res_blocks)])
        self.norm_out = nn.GroupNorm(32, block_out, eps=1e-6)
        self.conv_out = nn.Conv2d(block_out, z_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_in(x)
        for i_level in range(self.num_blocks):
            for i_block in range(self.num_res_blocks):
                x = self.down[i_level].block[i_block](x)
            if i_level < self.num_blocks - 1:
                x = self.down[i_level].downsample(x)

        for block in self.mid_block:
            x = block(x)

        x = self.norm_out(x)
        x = swish(x)
        x = self.conv_out(x)
        return x


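# Usage sketch (not part of the original module): with a ch_mult of length 4
# the encoder downsamples by 2 ** 3 = 8. Sizes are illustrative assumptions.
def _encoder_shape_check() -> None:
    enc = Encoder(ch=64, out_ch=3, in_channels=3, num_res_blocks=2,
                  z_channels=32, ch_mult=(1, 2, 2, 4))
    x = torch.randn(1, 3, 64, 64)
    assert enc(x).shape == (1, 32, 8, 8)

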
def depth_to_space(x: torch.Tensor, block_size: int) -> torch.Tensor:
    # Fold channels into space: (*, C, H, W) -> (*, C / block_size**2,
    # H * block_size, W * block_size), with block-major channel ordering.
    if x.dim() < 3:
        raise ValueError("Expected a channels-first (*CHW) tensor of at least 3 dims.")
    c, h, w = x.shape[-3:]
    s = block_size**2
    if c % s != 0:
        raise ValueError(f"Expected C divisible by {s}, but got C={c}.")

    outer_dims = x.shape[:-3]
    x = x.view(-1, block_size, block_size, c // s, h, w)
    x = x.permute(0, 3, 4, 1, 5, 2)
    x = x.contiguous().view(*outer_dims, c // s, h * block_size, w * block_size)
    return x


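# Note (my reading, not stated in the original): depth_to_space splits the
# channel axis block-major (TensorFlow-style), which is NOT the channel
# ordering of torch.nn.functional.pixel_shuffle (channel-major). The two
# agree on shape but generally not on values, as this sketch illustrates.
def _depth_to_space_ordering_check() -> None:
    import torch.nn.functional as F

    x = torch.arange(32, dtype=torch.float32).view(1, 8, 2, 2)
    a = depth_to_space(x, block_size=2)
    b = F.pixel_shuffle(x, upscale_factor=2)
    assert a.shape == b.shape == (1, 2, 4, 4)
    assert not torch.equal(a, b)  # different channel orderings

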
class Upsampler(nn.Module):
    # 2x upsampling: a 3x3 conv expands channels 4x, then depth_to_space
    # folds the extra channels into a doubled spatial resolution.
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim * 4, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return depth_to_space(self.conv1(x), block_size=2)


class AdaptiveGroupNorm(nn.Module):
    # Affine-free GroupNorm whose per-channel scale and bias are predicted
    # from the per-channel std and mean of the quantized latent ("style"):
    #   y = gamma(std(q)) * GN(x) + beta(mean(q))
    def __init__(self, z_channel: int, in_filters: int, num_groups: int = 32, eps: float = 1e-6) -> None:
        super().__init__()
        self.gn = nn.GroupNorm(num_groups=num_groups, num_channels=in_filters, eps=eps, affine=False)
        self.gamma = nn.Linear(z_channel, in_filters)
        self.beta = nn.Linear(z_channel, in_filters)
        self.eps = eps

    def forward(self, x: torch.Tensor, quantizer: torch.Tensor) -> torch.Tensor:
        bsz, channels, _, _ = x.shape

        # Per-channel std of the latent drives the scale.
        scale = rearrange(quantizer, "b c h w -> b c (h w)")
        scale = scale.var(dim=-1) + self.eps
        scale = scale.sqrt()
        scale = self.gamma(scale).view(bsz, channels, 1, 1)

        # Per-channel mean of the latent drives the bias.
        bias = rearrange(quantizer, "b c h w -> b c (h w)")
        bias = bias.mean(dim=-1)
        bias = self.beta(bias).view(bsz, channels, 1, 1)

        x = self.gn(x)
        return scale * x + bias


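# Usage sketch (not part of the original module): the latent and the feature
# map may have different channel counts and resolutions; only the latent's
# per-channel statistics matter. Sizes are illustrative assumptions.
def _adaptive_group_norm_shape_check() -> None:
    agn = AdaptiveGroupNorm(z_channel=32, in_filters=64)
    x = torch.randn(2, 64, 16, 16)  # features being normalized
    q = torch.randn(2, 32, 8, 8)    # quantized latent supplying style stats
    assert agn(x, q).shape == (2, 64, 16, 16)

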
class Decoder(nn.Module):
    # Mirror of the Encoder: each level applies AdaptiveGroupNorm conditioned
    # on the latent, then ResBlocks, then a 2x Upsampler (all levels but the
    # lowest), for an overall spatial expansion of 2 ** (len(ch_mult) - 1).
    def __init__(
        self,
        *,
        ch: int,
        out_ch: int,
        in_channels: int,
        num_res_blocks: int,
        z_channels: int,
        ch_mult: Sequence[int] = (1, 2, 2, 4),
        resolution: Optional[int] = None,
        double_z: bool = False,
    ) -> None:
        super().__init__()
        del in_channels, resolution, double_z  # config compatibility, unused
        self.num_blocks = len(ch_mult)
        self.num_res_blocks = num_res_blocks

        block_in = ch * ch_mult[self.num_blocks - 1]
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, padding=1, bias=True)
        self.mid_block = nn.ModuleList([ResBlock(block_in, block_in) for _ in range(self.num_res_blocks)])

        self.up = nn.ModuleList()
        self.adaptive = nn.ModuleList()
        for i_level in reversed(range(self.num_blocks)):
            block = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            self.adaptive.insert(0, AdaptiveGroupNorm(z_channels, block_in))
            for _ in range(self.num_res_blocks):
                block.append(ResBlock(block_in, block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            if i_level > 0:
                up.upsample = Upsampler(block_in)
            self.up.insert(0, up)

        self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, padding=1)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        style = z.clone()  # the latent itself conditions AdaptiveGroupNorm
        z = self.conv_in(z)

        for block in self.mid_block:
            z = block(z)

        for i_level in reversed(range(self.num_blocks)):
            z = self.adaptive[i_level](z, style)
            for i_block in range(self.num_res_blocks):
                z = self.up[i_level].block[i_block](z)
            if i_level > 0:
                z = self.up[i_level].upsample(z)

        z = self.norm_out(z)
        z = swish(z)
        z = self.conv_out(z)
        return z


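# Usage sketch (not part of the original module): the decoder inverts the
# encoder's 8x downsampling under the same illustrative configuration.
def _decoder_shape_check() -> None:
    dec = Decoder(ch=64, out_ch=3, in_channels=3, num_res_blocks=2,
                  z_channels=32, ch_mult=(1, 2, 2, 4))
    z = torch.randn(1, 32, 8, 8)
    assert dec(z).shape == (1, 3, 64, 64)

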
class GANDecoder(nn.Module):
    # Identical to Decoder except that a fresh Gaussian noise map is
    # concatenated to the latent at the input (hence conv_in takes
    # z_channels * 2 channels), making decoding stochastic.
    def __init__(
        self,
        *,
        ch: int,
        out_ch: int,
        in_channels: int,
        num_res_blocks: int,
        z_channels: int,
        ch_mult: Sequence[int] = (1, 2, 2, 4),
        resolution: Optional[int] = None,
        double_z: bool = False,
    ) -> None:
        super().__init__()
        del in_channels, resolution, double_z  # config compatibility, unused
        self.num_blocks = len(ch_mult)
        self.num_res_blocks = num_res_blocks

        block_in = ch * ch_mult[self.num_blocks - 1]
        self.conv_in = nn.Conv2d(z_channels * 2, block_in, kernel_size=3, padding=1, bias=True)
        self.mid_block = nn.ModuleList([ResBlock(block_in, block_in) for _ in range(self.num_res_blocks)])

        self.up = nn.ModuleList()
        self.adaptive = nn.ModuleList()
        for i_level in reversed(range(self.num_blocks)):
            block = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            self.adaptive.insert(0, AdaptiveGroupNorm(z_channels, block_in))
            for _ in range(self.num_res_blocks):
                block.append(ResBlock(block_in, block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            if i_level > 0:
                up.upsample = Upsampler(block_in)
            self.up.insert(0, up)

        self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, padding=1)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        style = z.clone()  # style stats come from the latent, not the noise
        noise = torch.randn_like(z)
        z = torch.cat([z, noise], dim=1)
        z = self.conv_in(z)

        for block in self.mid_block:
            z = block(z)

        for i_level in reversed(range(self.num_blocks)):
            z = self.adaptive[i_level](z, style)
            for i_block in range(self.num_res_blocks):
                z = self.up[i_level].block[i_block](z)
            if i_level > 0:
                z = self.up[i_level].upsample(z)

        z = self.norm_out(z)
        z = swish(z)
        z = self.conv_out(z)
        return z


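# Usage sketch (not part of the original module): because a fresh noise map
# is drawn on every call, two decodes of the same latent generally differ.
# Sizes are illustrative assumptions.
def _gan_decoder_stochasticity_check() -> None:
    dec = GANDecoder(ch=64, out_ch=3, in_channels=3, num_res_blocks=2,
                     z_channels=32, ch_mult=(1, 2, 2, 4))
    z = torch.randn(1, 32, 8, 8)
    y1, y2 = dec(z), dec(z)
    assert y1.shape == (1, 3, 64, 64)
    assert not torch.allclose(y1, y2)  # different noise draws

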
class BitDanceAutoencoder(ModelMixin, ConfigMixin):
    # Binary autoencoder: the encoder output is quantized elementwise to
    # +1/-1 by sign, and the decoder reconstructs from those bits.
    @register_to_config
    def __init__(self, ddconfig: Dict[str, Any], gan_decoder: bool = False) -> None:
        super().__init__()
        self.encoder = Encoder(**ddconfig)
        self.decoder = GANDecoder(**ddconfig) if gan_decoder else Decoder(**ddconfig)

    @property
    def z_channels(self) -> int:
        return int(self.config.ddconfig["z_channels"])

    @property
    def patch_size(self) -> int:
        ch_mult = self.config.ddconfig["ch_mult"]
        return 2 ** (len(ch_mult) - 1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        h = self.encoder(x)
        # Hard sign quantization; note there is no straight-through estimator
        # here, so gradients do not flow through the quantizer.
        codebook_value = torch.tensor([1.0], device=h.device, dtype=h.dtype)
        quant_h = torch.where(h > 0, codebook_value, -codebook_value)
        return quant_h

    def decode(self, quant: torch.Tensor) -> torch.Tensor:
        return self.decoder(quant)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        quant = self.encode(x)
        return self.decode(quant)
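

# End-to-end usage sketch (not part of the original module). The ddconfig
# values below are illustrative assumptions, not the released checkpoint's
# configuration; load the published weights via from_pretrained in practice.
def _roundtrip_example() -> None:
    ddconfig = {
        "ch": 64,
        "out_ch": 3,
        "in_channels": 3,
        "num_res_blocks": 2,
        "z_channels": 32,
        "ch_mult": (1, 2, 2, 4),
    }
    model = BitDanceAutoencoder(ddconfig=ddconfig).eval()

    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        quant = model.encode(x)      # binary +1/-1 latents
        recon = model.decode(quant)  # reconstructed image

    assert model.patch_size == 8     # 2 ** (len(ch_mult) - 1)
    assert quant.shape == (1, 32, 32, 32)
    assert recon.shape == x.shape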