BiliSakura committed
Commit 3f7ea9b · verified · 1 Parent(s): 5b41b52

Update all files for BitDance-ImageNet-diffusers

BitDance_H_1x/transformer/src/diff_head_parallel.py ADDED
@@ -0,0 +1,310 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .sampling_parallel import euler_maruyama


def timestep_embedding(t, dim, max_period=10000, time_factor: float = 1000.0):
    half = dim // 2
    t = time_factor * t.float()
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        / half
    )

    args = t[:, None] * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


def time_shift_sana(t: torch.Tensor, flow_shift: float = 1., sigma: float = 1.):
    return (1 / flow_shift) / ((1 / flow_shift) + (1 / t - 1) ** sigma)

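
# Reference note (illustrative, not used elsewhere in this file): for sigma == 1 the
# shift above simplifies to t' = t / (t + flow_shift * (1 - t)), so flow_shift > 1 pushes
# sampled times toward t = 0, the pure-noise end in this file's convention. A minimal
# standalone check of that identity:
def _check_time_shift_sana_identity():
    t = torch.linspace(0.05, 0.95, 5)
    flow_shift = 3.0
    shifted = time_shift_sana(t, flow_shift)
    assert torch.allclose(shifted, t / (t + flow_shift * (1 - t)))
    return shifted
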
class DiffHead(nn.Module):
    """Diffusion Loss"""

    def __init__(
        self,
        ch_target,
        ch_cond,
        ch_latent,
        depth_latent,
        depth_adanln,
        grad_checkpointing=False,
        time_shift=1.,
        time_schedule='logit_normal',
        parallel_num=4,
        P_std: float = 1.,
        P_mean: float = 0.,
    ):
        super(DiffHead, self).__init__()
        self.ch_target = ch_target
        self.time_shift = time_shift
        self.time_schedule = time_schedule
        self.P_std = P_std
        self.P_mean = P_mean

        self.net = TransEncoder(
            in_channels=ch_target,
            model_channels=ch_latent,
            z_channels=ch_cond,
            num_res_blocks=depth_latent,
            num_ada_ln_blocks=depth_adanln,
            grad_checkpointing=grad_checkpointing,
            parallel_num=parallel_num,
        )

    def forward(self, x, cond):
        with torch.autocast(device_type="cuda", enabled=False):
            with torch.no_grad():
                if self.time_schedule == 'logit_normal':
                    t = (torch.randn((x.shape[0]), device=x.device) * self.P_std + self.P_mean).sigmoid()
                    if self.time_shift != 1.:
                        t = time_shift_sana(t, self.time_shift)
                elif self.time_schedule == 'uniform':
                    t = torch.rand((x.shape[0]), device=x.device)
                    if self.time_shift != 1.:
                        t = time_shift_sana(t, self.time_shift)
                else:
                    raise NotImplementedError(f"unknown time_schedule {self.time_schedule}")
                e = torch.randn_like(x)
                ti = t.view(-1, 1, 1)
                z = (1.0 - ti) * e + ti * x
                v = (x - z) / (1 - ti).clamp_min(0.05)

        x_pred = self.net(z, t, cond)
        v_pred = (x_pred - z) / (1 - ti).clamp_min(0.05)

        with torch.autocast(device_type="cuda", enabled=False):
            v_pred = v_pred.float()
            loss = torch.mean((v - v_pred) ** 2)
        return loss

    def sample(
        self,
        z,
        cfg,
        num_sampling_steps,
    ):
        return euler_maruyama(
            self.ch_target,
            self.net.forward,
            z,
            cfg,
            num_sampling_steps=num_sampling_steps,
            time_shift=self.time_shift,
        )

    def initialize_weights(self):
        self.net.initialize_weights()

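
# Illustrative note on the loss above: with z = (1 - t) * e + t * x, the target
# v = (x - z) / (1 - t) is algebraically just x - e whenever 1 - t >= 0.05; the
# clamp_min(0.05), applied to both v and v_pred, caps the implicit 1 / (1 - t)
# weighting of the x-space error as t approaches 1. A standalone sketch:
def _check_velocity_target():
    x = torch.randn(2, 4, 8)                           # stand-in for target tokens
    e = torch.randn_like(x)                            # Gaussian noise
    t = torch.rand(2).clamp(0.1, 0.9).view(-1, 1, 1)   # keep 1 - t above the clamp
    z = (1.0 - t) * e + t * x
    v = (x - z) / (1 - t).clamp_min(0.05)              # same expression as DiffHead.forward
    assert torch.allclose(v, x - e, atol=1e-4)
    return v
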
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t):
        t_freq = timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb

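
# Illustrative shape sketch: TimestepEmbedder maps a batch of scalar times to
# (batch, hidden_size) vectors via the sinusoidal features above and a two-layer MLP.
# The hidden size here is an arbitrary example value, not a model hyperparameter.
def _check_timestep_embedder_shape(hidden_size=1024, batch=8):
    emb = TimestepEmbedder(hidden_size)
    t = torch.rand(batch)              # scalar timesteps in (0, 1)
    out = emb(t)
    assert out.shape == (batch, hidden_size)
    return out
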
class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.norm = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=True)
        hidden_dim = int(channels * 1.5)
        self.w1 = nn.Linear(channels, hidden_dim * 2, bias=True)
        self.w2 = nn.Linear(hidden_dim, channels, bias=True)

    def forward(self, x, scale, shift, gate):
        h = self.norm(x) * (1 + scale) + shift
        h1, h2 = self.w1(h).chunk(2, dim=-1)
        h = self.w2(F.silu(h1) * h2)
        return x + h * gate

class FinalLayer(nn.Module):
    def __init__(self, channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=False)
        self.ada_ln_modulation = nn.Linear(channels, channels * 2, bias=True)
        self.linear = nn.Linear(channels, out_channels, bias=True)

    def forward(self, x, y):
        scale, shift = self.ada_ln_modulation(y).chunk(2, dim=-1)
        x = self.norm_final(x) * (1.0 + scale) + shift
        x = self.linear(x)
        return x

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        n_head,
    ):
        super().__init__()
        assert dim % n_head == 0
        self.dim = dim
        self.head_dim = dim // n_head
        self.scale = self.head_dim**-0.5
        self.n_head = n_head
        total_kv_dim = (self.n_head * 3) * self.head_dim

        self.wqkv = nn.Linear(dim, total_kv_dim, bias=True)
        self.wo = nn.Linear(dim, dim, bias=True)

    def forward(
        self,
        x: torch.Tensor,
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wqkv(x).chunk(3, dim=-1)

        xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_head, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_head, self.head_dim)

        xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
        xq = xq * self.scale
        attn = xq @ xk.transpose(-1, -2)
        attn = F.softmax(attn, dim=-1)
        output = (attn @ xv).transpose(1, 2).contiguous()

        # output = flash_attn_func(
        #     xq,
        #     xk,
        #     xv,
        #     causal=False,
        # )

        output = output.view(bsz, seqlen, self.dim)

        output = self.wo(output)
        return output

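
# Illustrative note: the explicit softmax attention above should match PyTorch's fused
# F.scaled_dot_product_attention (non-causal, no dropout) up to floating-point error.
# A standalone comparison sketch with arbitrary example sizes:
def _check_attention_matches_sdpa(dim=128, n_head=2, bsz=2, seqlen=16):
    attn = Attention(dim, n_head)
    x = torch.randn(bsz, seqlen, dim)
    xq, xk, xv = attn.wqkv(x).chunk(3, dim=-1)
    xq, xk, xv = (z.view(bsz, seqlen, n_head, -1).transpose(1, 2) for z in (xq, xk, xv))
    ref = F.scaled_dot_product_attention(xq, xk, xv, is_causal=False)
    ref = attn.wo(ref.transpose(1, 2).reshape(bsz, seqlen, dim))
    assert torch.allclose(attn(x), ref, atol=1e-5)
    return ref
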
class TransBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.norm1 = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=True)
        self.attn = Attention(channels, n_head=channels // 64)

        self.norm2 = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=True)
        hidden_dim = int(channels * 1.5)
        self.w1 = nn.Linear(channels, hidden_dim * 2, bias=True)
        self.w2 = nn.Linear(hidden_dim, channels, bias=True)

    def forward(self, x, scale1, shift1, gate1, scale2, shift2, gate2):
        h = self.norm1(x) * (1 + scale1) + shift1
        h = self.attn(h)
        x = x + h * gate1
        h = self.norm2(x) * (1 + scale2) + shift2
        h1, h2 = self.w1(h).chunk(2, dim=-1)
        h = self.w2(F.silu(h1) * h2)
        return x + h * gate2

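
# Illustrative note: each TransBlock takes six modulation tensors (scale/shift/gate for
# the attention branch and for the gated MLP). Because the gates multiply the residual
# branches, all-zero modulation -- the state of the shared adaLN projections right after
# the zero-initialization in TransEncoder below -- leaves the block as an identity map:
def _check_trans_block_identity(channels=128, bsz=2, seqlen=16):
    block = TransBlock(channels)
    x = torch.randn(bsz, seqlen, channels)
    zeros = torch.zeros(bsz, 1, channels)          # broadcasts over the sequence dim
    out = block(x, zeros, zeros, zeros, zeros, zeros, zeros)
    assert torch.equal(out, x)
    return out
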
class TransEncoder(nn.Module):

    def __init__(
        self,
        in_channels,
        model_channels,
        z_channels,
        num_res_blocks,
        num_ada_ln_blocks=2,
        grad_checkpointing=False,
        parallel_num=4,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.grad_checkpointing = grad_checkpointing
        self.parallel_num = parallel_num

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)

        self.input_proj = nn.Linear(in_channels, model_channels)
        self.res_blocks = nn.ModuleList()
        for i in range(num_res_blocks):
            self.res_blocks.append(
                TransBlock(
                    model_channels,
                )
            )
        # share adaLN for consecutive blocks, to save computation and parameters
        self.ada_ln_blocks = nn.ModuleList()
        for i in range(num_ada_ln_blocks):
            self.ada_ln_blocks.append(
                nn.Linear(model_channels, model_channels * 6, bias=True)
            )
        self.ada_ln_switch_freq = max(1, num_res_blocks // num_ada_ln_blocks)
        assert (
            num_res_blocks % self.ada_ln_switch_freq
        ) == 0, "num_res_blocks must be divisible by num_ada_ln_blocks"
        self.final_layer = FinalLayer(model_channels, self.out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        for block in self.ada_ln_blocks:
            nn.init.constant_(block.weight, 0)
            nn.init.constant_(block.bias, 0)

        # Zero-out output layers
        nn.init.constant_(self.final_layer.ada_ln_modulation.weight, 0)
        nn.init.constant_(self.final_layer.ada_ln_modulation.bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    @torch.compile()
    def forward(self, x, t, c):
        x = self.input_proj(x)
        t = self.time_embed(t).unsqueeze(1)
        c = self.cond_embed(c)

        y = F.silu(t + c)
        scale1, shift1, gate1, scale2, shift2, gate2 = self.ada_ln_blocks[0](y).chunk(6, dim=-1)

        for i, block in enumerate(self.res_blocks):
            if i > 0 and i % self.ada_ln_switch_freq == 0:
                ada_ln_block = self.ada_ln_blocks[i // self.ada_ln_switch_freq]
                scale1, shift1, gate1, scale2, shift2, gate2 = ada_ln_block(y).chunk(6, dim=-1)
            x = block(x, scale1, shift1, gate1, scale2, shift2, gate2)

        return self.final_layer(x, y)
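

# Illustrative end-to-end sketch. The hyperparameters and tensor shapes below are
# placeholders for illustration, not the released BitDance configuration; running this
# also requires the relative import of `euler_maruyama` at the top of the file to
# resolve, and the first call triggers torch.compile on TransEncoder.forward.
def _example_diff_head_loss(bsz=2, seq=4, ch_target=16, ch_cond=512):
    head = DiffHead(
        ch_target=ch_target,     # channels of the tokens being denoised
        ch_cond=ch_cond,         # channels of the conditioning features
        ch_latent=256,           # width of the diffusion head
        depth_latent=4,          # number of TransBlocks
        depth_adanln=2,          # number of shared adaLN projections
        time_shift=3.0,
    )
    x = torch.randn(bsz, seq, ch_target)    # target tokens (shape assumed for illustration)
    cond = torch.randn(bsz, seq, ch_cond)   # per-token conditioning (shape assumed)
    return head(x, cond)                    # scalar training loss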