BiliSakura committed on
Commit
ddac775
·
verified ·
1 Parent(s): 4ec11b6

Update all files for BitDance-ImageNet-diffusers

Browse files
Files changed (1) hide show
  1. BitDance_B_1x/transformer/sampling.py +119 -0
BitDance_B_1x/transformer/sampling.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def time_shift_sana(t: torch.Tensor, flow_shift: float = 1., sigma: float = 1.):
    """Apply the SANA-style timestep warp to ``t``.

    Maps t to (1/flow_shift) / ((1/flow_shift) + (1/t - 1) ** sigma).
    With flow_shift == 1 and sigma == 1 this is the identity on (0, 1];
    at t == 0 the 1/t term is inf in torch, so the result is 0.
    """
    inv_shift = 1 / flow_shift
    return inv_shift / (inv_shift + (1 / t - 1) ** sigma)
5
+
6
def get_score_from_velocity(velocity, x, t):
    """Convert a flow-matching velocity at (x, t) into a score estimate.

    Uses the linear interpolation path alpha_t = t, sigma_t = 1 - t
    (hence d_alpha/dt = 1, d_sigma/dt = -1); for this path the variance
    term simplifies to (1 - t)^2 - t * (-1) * (1 - t) = 1 - t.
    """
    alpha, d_alpha = t, 1
    sigma, d_sigma = 1 - t, -1
    ratio = alpha / d_alpha
    variance = sigma ** 2 - ratio * d_sigma * sigma
    return (ratio * velocity - x) / variance
14
+
15
+
16
def get_velocity_from_cfg(velocity, cfg, cfg_mult):
    """Combine classifier-free-guidance velocity halves.

    When ``cfg_mult == 2`` the batch dim stacks [conditional; unconditional]
    halves; the guided velocity uncond + cfg * (cond - uncond) is returned.
    Otherwise ``velocity`` passes through unchanged.
    """
    if cfg_mult != 2:
        return velocity
    cond, uncond = torch.chunk(velocity, 2, dim=0)
    return uncond + cfg * (cond - uncond)
21
+
22
+
23
@torch.compile()
def euler_step(x, v, dt: float, cfg: float, cfg_mult: int):
    """One explicit-Euler ODE update: x <- x + v_guided * dt.

    Autocast is disabled and v is cast to float32 so the integration runs
    in full precision regardless of any surrounding AMP context.
    """
    with torch.amp.autocast("cuda", enabled=False):
        guided = get_velocity_from_cfg(v.to(torch.float32), cfg, cfg_mult)
        x = x + guided * dt
    return x
30
+
31
+
32
@torch.compile()
def euler_maruyama_step(x, v, t, dt: float, cfg: float, cfg_mult: int):
    """One Euler–Maruyama SDE update on the linear flow path.

    The drift is the guided velocity plus (1 - t) * score, and the
    diffusion term adds Gaussian noise with std sqrt(2 * (1 - t) * dt).
    All arithmetic is forced to float32 with autocast disabled.
    """
    with torch.amp.autocast("cuda", enabled=False):
        guided = get_velocity_from_cfg(v.to(torch.float32), cfg, cfg_mult)
        score = get_score_from_velocity(guided, x, t)
        drift = guided + (1 - t) * score
        noise_std = (2.0 * (1.0 - t) * dt) ** 0.5
        x = x + drift * dt + noise_std * torch.randn_like(x)
    return x
42
+
43
+
44
def euler_maruyama(
    input_dim,
    forward_fn,
    c: torch.Tensor,
    cfg: float = 1.0,
    num_sampling_steps: int = 20,
    last_step_size: float = 0.05,
    time_shift: float = 1.,
):
    """Draw samples by Euler–Maruyama SDE integration plus one final ODE step.

    Integrates from t = 0 up to t = 1 - last_step_size with stochastic
    steps, then takes a single deterministic Euler step of size
    ``last_step_size``. When cfg > 1, ``c`` is assumed to stack
    conditional/unconditional halves along dim 0 (cfg_mult == 2).
    Returns a tensor with the same batch size as ``c`` (the sample is
    duplicated cfg_mult times on dim 0).
    """
    cfg_mult = 2 if cfg > 1.0 else 1

    # One latent per condition; c's batch holds cfg_mult copies.
    shape = list(c.shape)
    shape[0] = shape[0] // cfg_mult
    shape[-1] = input_dim
    x = torch.randn(shape, device=c.device)

    # Time grid on [0, 1 - last_step_size], optionally warped SANA-style.
    grid = torch.linspace(
        0, 1 - last_step_size, num_sampling_steps + 1,
        device=c.device, dtype=torch.float32,
    )
    grid = time_shift_sana(grid, time_shift)
    dts = grid[1:] - grid[:-1]

    # Scalar tensor (not a Python float) to avoid torch.compile warnings.
    t = torch.tensor(0.0, device=c.device, dtype=torch.float32)
    t_batch = torch.zeros(c.shape[0], device=c.device)

    for step in range(num_sampling_steps):
        t_batch[:] = t
        doubled = torch.cat([x] * cfg_mult, dim=0)
        prediction = forward_fn(doubled, t_batch, c)
        # Convert the model output (apparently an x1/endpoint prediction) to a
        # velocity (pred - x_t) / (1 - t), clamping the denominator so the
        # late steps stay finite — TODO confirm the x1-prediction assumption.
        v = (prediction - doubled) / (1 - t_batch.view(-1, 1)).clamp_min(0.05)
        x = euler_maruyama_step(x, v, t, dts[step], cfg, cfg_mult)
        t += dts[step]

    # Final deterministic step of size last_step_size.
    # NOTE(review): t_batch is set to 1 - last_step_size although the
    # accumulated t equals grid[-1], which differs when time_shift != 1 —
    # confirm this mismatch is intended.
    doubled = torch.cat([x] * cfg_mult, dim=0)
    t_batch[:] = 1 - last_step_size
    prediction = forward_fn(doubled, t_batch, c)
    v = (prediction - doubled) / (1 - t_batch.view(-1, 1)).clamp_min(0.05)
    x = euler_step(x, v, last_step_size, cfg, cfg_mult)

    return torch.cat([x] * cfg_mult, dim=0)
92
+
93
+
94
def euler(
    input_dim,
    forward_fn,
    c,
    cfg: float = 1.0,
    num_sampling_steps: int = 50,
):
    """Draw samples by fixed-step Euler integration of the flow ODE.

    Steps t from 0 to 1 in ``num_sampling_steps`` uniform increments.
    When cfg > 1, ``c`` is assumed to stack conditional/unconditional
    halves along dim 0 (cfg_mult == 2). Returns a tensor with the same
    batch size as ``c`` (sample duplicated cfg_mult times on dim 0).
    """
    cfg_mult = 2 if cfg > 1.0 else 1

    # One latent per condition; c's batch holds cfg_mult copies.
    shape = list(c.shape)
    shape[0] = shape[0] // cfg_mult
    shape[-1] = input_dim
    x = torch.randn(shape, device=c.device)

    dt = 1.0 / num_sampling_steps
    t = 0
    t_batch = torch.zeros(c.shape[0], device=c.device)
    for _ in range(num_sampling_steps):
        t_batch[:] = t
        doubled = torch.cat([x] * cfg_mult, dim=0)
        velocity = forward_fn(doubled, t_batch, c)
        x = euler_step(x, velocity, dt, cfg, cfg_mult)
        t += dt

    return torch.cat([x] * cfg_mult, dim=0)