BiliSakura committed on
Commit 757c1f0 · verified · 1 Parent(s): 5991d46

Update all files for BitDance-ImageNet-diffusers
BitDance_B_16x/transformer/diff_head.py ADDED
@@ -0,0 +1,253 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.checkpoint import checkpoint
+
+ from .sampling import euler_maruyama
+
+
+ def timestep_embedding(t, dim, max_period=10000, time_factor: float = 1000.0):
+     # Sinusoidal timestep embedding (DiT-style): scale t by time_factor, then
+     # concatenate cos/sin features at log-spaced frequencies.
+     half = dim // 2
+     t = time_factor * t.float()
+     freqs = torch.exp(
+         -math.log(max_period)
+         * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
+         / half
+     )
+
+     args = t[:, None] * freqs[None]
+     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+     if dim % 2:
+         # Pad with a zero column when dim is odd.
+         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+     if torch.is_floating_point(t):
+         embedding = embedding.to(t)
+     return embedding
+
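For orientation, a quick shape check of this embedding (illustrative values, assuming the definition above is in scope):

    import torch
    emb = timestep_embedding(torch.tensor([0.25, 0.5]), dim=256)
    print(emb.shape)  # torch.Size([2, 256]); first 128 dims are cosines, last 128 are sines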
+ def time_shift_sana(t: torch.Tensor, flow_shift: float = 1., sigma: float = 1.):
+     # SANA-style timestep shift: for sigma = 1 this reduces to t / (t + flow_shift * (1 - t)),
+     # so flow_shift = 1 is the identity and flow_shift > 1 pushes t toward 0 (the noise end here).
+     return (1 / flow_shift) / ((1 / flow_shift) + (1 / t - 1) ** sigma)
+
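A quick numeric sanity check of the shift (hypothetical values, not part of the commit):

    t = torch.tensor([0.5])
    print(time_shift_sana(t, flow_shift=3.))  # tensor([0.2500]): 0.5 / (0.5 + 3 * 0.5)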
+ class DiffHead(nn.Module):
+     """Diffusion (flow-matching) loss head."""
+
+     def __init__(
+         self,
+         ch_target,
+         ch_cond,
+         ch_latent,
+         depth_latent,
+         depth_adanln,
+         grad_checkpointing=False,
+         time_shift=1.,
+         time_schedule='logit_normal',
+         P_std: float = 1.,
+         P_mean: float = 0.,
+     ):
+         super().__init__()
+         self.ch_target = ch_target
+         self.time_shift = time_shift
+         self.time_schedule = time_schedule
+         self.P_std = P_std
+         self.P_mean = P_mean
+
+         self.net = MlpEncoder(
+             in_channels=ch_target,
+             model_channels=ch_latent,
+             z_channels=ch_cond,
+             num_res_blocks=depth_latent,
+             num_ada_ln_blocks=depth_adanln,
+             grad_checkpointing=grad_checkpointing,
+         )
+
+     def forward(self, x, cond):
+         # Sample timesteps and build the noisy interpolant in fp32, without grad.
+         with torch.autocast(device_type="cuda", enabled=False):
+             with torch.no_grad():
+                 if self.time_schedule == 'logit_normal':
+                     t = (torch.randn((x.shape[0]), device=x.device) * self.P_std + self.P_mean).sigmoid()
+                     if self.time_shift != 1.:
+                         t = time_shift_sana(t, self.time_shift)
+                 elif self.time_schedule == 'uniform':
+                     t = torch.rand((x.shape[0]), device=x.device)
+                     if self.time_shift != 1.:
+                         t = time_shift_sana(t, self.time_shift)
+                 else:
+                     raise NotImplementedError(f"unknown time_schedule {self.time_schedule}")
+                 e = torch.randn_like(x)
+                 ti = t.view(-1, 1)
+                 # Linear interpolation: t = 0 is pure noise, t = 1 is data.
+                 z = (1.0 - ti) * e + ti * x
+                 # Target velocity; the clamp bounds the 1 / (1 - t) blow-up near t = 1.
+                 v = (x - z) / (1 - ti).clamp_min(0.05)
+
+         x_pred = self.net(z, t, cond)
+         v_pred = (x_pred - z) / (1 - ti).clamp_min(0.05)
+
+         with torch.autocast(device_type="cuda", enabled=False):
+             v_pred = v_pred.float()
+             loss = torch.mean((v - v_pred) ** 2)
+         return loss
+
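To make the regression target concrete: since z = (1 - t) * e + t * x, away from the clamp the target velocity is v = (x - z) / (1 - t) = x - e, the straight-line direction from noise to data. A tiny numeric check (hypothetical values, not part of the commit):

    x = torch.tensor([[2.0]]); e = torch.tensor([[0.5]]); ti = torch.tensor([[0.4]])
    z = (1 - ti) * e + ti * x               # 0.6 * 0.5 + 0.4 * 2.0 = 1.1
    v = (x - z) / (1 - ti).clamp_min(0.05)  # (2.0 - 1.1) / 0.6 = 1.5 = x - e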
+     def sample(
+         self,
+         z,
+         cfg,
+         num_sampling_steps,
+     ):
+         return euler_maruyama(
+             self.ch_target,
+             self.net.forward,
+             z,
+             cfg,
+             num_sampling_steps=num_sampling_steps,
+             time_shift=self.time_shift,
+         )
+
+     def initialize_weights(self):
+         self.net.initialize_weights()
+
+
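As a minimal training-side usage sketch (all sizes are hypothetical assumptions, not values from this repo; the forward pass expects PyTorch's CUDA autocast context to be usable):

    head = DiffHead(ch_target=32, ch_cond=768, ch_latent=1024,
                    depth_latent=6, depth_adanln=2)  # hypothetical sizes
    x = torch.randn(16, 32)      # targets, e.g. per-token latents
    cond = torch.randn(16, 768)  # conditioning from the AR transformer
    loss = head(x, cond)         # scalar flow-matching loss

Sampling goes through head.sample(), which defers to euler_maruyama from .sampling; its signature is only visible from that call site, so it is not reproduced here.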
+ class TimestepEmbedder(nn.Module):
+     """
+     Embeds scalar timesteps into vector representations.
+     """
+
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     def forward(self, t):
+         t_freq = timestep_embedding(t, self.frequency_embedding_size)
+         t_emb = self.mlp(t_freq)
+         return t_emb
+
+
+ class ResBlock(nn.Module):
+     # Residual MLP block with adaLN modulation and a SwiGLU-style gated hidden layer.
+     def __init__(self, channels):
+         super().__init__()
+         self.channels = channels
+         self.norm = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=True)
+         hidden_dim = int(channels * 1.5)
+         self.w1 = nn.Linear(channels, hidden_dim * 2, bias=True)
+         self.w2 = nn.Linear(hidden_dim, channels, bias=True)
+
+     def forward(self, x, scale, shift, gate):
+         h = self.norm(x) * (1 + scale) + shift
+         h1, h2 = self.w1(h).chunk(2, dim=-1)
+         h = self.w2(F.silu(h1) * h2)
+         return x + h * gate
+
+
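One consequence of the zero-initialized adaLN projections in MlpEncoder below: at initialization scale = shift = gate = 0, so every ResBlock starts as the identity. A quick check (hypothetical sizes):

    blk = ResBlock(8)
    x = torch.randn(4, 8)
    zeros = torch.zeros(4, 8)
    assert torch.equal(blk(x, zeros, zeros, zeros), x)  # gate = 0 -> output is exactly x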
+ class FinalLayer(nn.Module):
+     def __init__(self, channels, out_channels):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=False)
+         self.ada_ln_modulation = nn.Linear(channels, channels * 2, bias=True)
+         self.linear = nn.Linear(channels, out_channels, bias=True)
+
+     def forward(self, x, y):
+         scale, shift = self.ada_ln_modulation(y).chunk(2, dim=-1)
+         x = self.norm_final(x) * (1.0 + scale) + shift
+         x = self.linear(x)
+         return x
+
+
+ class MlpEncoder(nn.Module):
+
+     def __init__(
+         self,
+         in_channels,
+         model_channels,
+         z_channels,
+         num_res_blocks,
+         num_ada_ln_blocks=2,
+         grad_checkpointing=False,
+     ):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.model_channels = model_channels
+         self.out_channels = in_channels
+         self.num_res_blocks = num_res_blocks
+         self.grad_checkpointing = grad_checkpointing
+
+         self.time_embed = TimestepEmbedder(model_channels)
+         self.cond_embed = nn.Linear(z_channels, model_channels)
+
+         self.input_proj = nn.Linear(in_channels, model_channels)
+         self.res_blocks = nn.ModuleList()
+         for i in range(num_res_blocks):
+             self.res_blocks.append(
+                 ResBlock(
+                     model_channels,
+                 )
+             )
+         # Share one adaLN projection across consecutive blocks to save compute and parameters.
+         self.ada_ln_blocks = nn.ModuleList()
+         for i in range(num_ada_ln_blocks):
+             self.ada_ln_blocks.append(
+                 nn.Linear(model_channels, model_channels * 3, bias=True)
+             )
+         self.ada_ln_switch_freq = max(1, num_res_blocks // num_ada_ln_blocks)
+         assert (
+             num_res_blocks % self.ada_ln_switch_freq
+         ) == 0, "num_res_blocks must be a multiple of ada_ln_switch_freq"
+         self.final_layer = FinalLayer(model_channels, self.out_channels)
+
+         self.initialize_weights()
+
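For intuition about the sharing schedule: with num_res_blocks = 8 and num_ada_ln_blocks = 2 (hypothetical values), ada_ln_switch_freq = 4, so blocks 0-3 read their scale/shift/gate from ada_ln_blocks[0] and blocks 4-7 from ada_ln_blocks[1]:

    num_res_blocks, switch_freq = 8, 4
    print([i // switch_freq for i in range(num_res_blocks)])  # [0, 0, 0, 0, 1, 1, 1, 1]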
+     def initialize_weights(self):
+         def _basic_init(module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+
+         self.apply(_basic_init)
+
+         # Initialize timestep embedding MLP
+         nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
+
+         # Zero-init adaLN projections so every ResBlock starts as the identity.
+         for block in self.ada_ln_blocks:
+             nn.init.constant_(block.weight, 0)
+             nn.init.constant_(block.bias, 0)
+
+         # Zero-out output layers
+         nn.init.constant_(self.final_layer.ada_ln_modulation.weight, 0)
+         nn.init.constant_(self.final_layer.ada_ln_modulation.bias, 0)
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
+
+     @torch.compile()
+     def forward(self, x, t, c):
+         """
+         Apply the model to an input batch.
+         :param x: an [N x C] Tensor of inputs.
+         :param t: a 1-D batch of timesteps.
+         :param c: conditioning from AR transformer.
+         :return: an [N x C] Tensor of outputs.
+         """
+         x = self.input_proj(x)
+         t = self.time_embed(t)
+         c = self.cond_embed(c)
+
+         y = F.silu(t + c)
+         scale, shift, gate = self.ada_ln_blocks[0](y).chunk(3, dim=-1)
+         if self.grad_checkpointing and self.training:
+             for i, block in enumerate(self.res_blocks):
+                 if i > 0 and i % self.ada_ln_switch_freq == 0:
+                     ada_ln_block = self.ada_ln_blocks[i // self.ada_ln_switch_freq]
+                     scale, shift, gate = ada_ln_block(y).chunk(3, dim=-1)
+                 x = checkpoint(block, x, scale, shift, gate, use_reentrant=False)
+         else:
+             for i, block in enumerate(self.res_blocks):
+                 if i > 0 and i % self.ada_ln_switch_freq == 0:
+                     ada_ln_block = self.ada_ln_blocks[i // self.ada_ln_switch_freq]
+                     scale, shift, gate = ada_ln_block(y).chunk(3, dim=-1)
+                 x = block(x, scale, shift, gate)
+
+         return self.final_layer(x, y)
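Finally, a standalone sketch of the denoiser call itself (all sizes are hypothetical; the first call triggers torch.compile, so expect one-time compilation latency):

    net = MlpEncoder(in_channels=32, model_channels=1024, z_channels=768,
                     num_res_blocks=6, num_ada_ln_blocks=2)
    z = torch.randn(16, 32)   # noisy targets
    t = torch.rand(16)        # timesteps in (0, 1)
    c = torch.randn(16, 768)  # AR conditioning
    x_pred = net(z, t, c)     # predicted clean targets, shape [16, 32]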